mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 05:29:50 +08:00
test: improve unit tests for controllers.inner_api (#32203)
This commit is contained in:
parent
31eba65fe0
commit
36c1f4d506
@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
313
api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py
Normal file
313
api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py
Normal file
@ -0,0 +1,313 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin endpoints
|
||||
|
||||
Tests endpoint structure (method existence) for all plugin APIs, plus
|
||||
handler-level logic tests for representative non-streaming endpoints.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.inner_api.plugin.plugin import (
|
||||
PluginFetchAppInfoApi,
|
||||
PluginInvokeAppApi,
|
||||
PluginInvokeEncryptApi,
|
||||
PluginInvokeLLMApi,
|
||||
PluginInvokeLLMWithStructuredOutputApi,
|
||||
PluginInvokeModerationApi,
|
||||
PluginInvokeParameterExtractorNodeApi,
|
||||
PluginInvokeQuestionClassifierNodeApi,
|
||||
PluginInvokeRerankApi,
|
||||
PluginInvokeSpeech2TextApi,
|
||||
PluginInvokeSummaryApi,
|
||||
PluginInvokeTextEmbeddingApi,
|
||||
PluginInvokeToolApi,
|
||||
PluginInvokeTTSApi,
|
||||
PluginUploadFileRequestApi,
|
||||
)
|
||||
|
||||
|
||||
def _extract_raw_post(cls):
|
||||
"""Extract the raw post() method from a plugin endpoint class.
|
||||
|
||||
Plugin endpoint methods are wrapped by several decorators (get_user_tenant,
|
||||
setup_required, plugin_inner_api_only, plugin_data). These decorators
|
||||
use @wraps where possible. This helper ensures we retrieve the original
|
||||
post(self, user_model, tenant_model, payload) function by unwrapping
|
||||
and, if necessary, walking the closure of the innermost wrapper.
|
||||
"""
|
||||
bottom = inspect.unwrap(cls.post)
|
||||
|
||||
# If unwrap() didn't get us to the raw function (e.g. if a decorator
|
||||
# missed @wraps), try to extract it from the closure if it looks like
|
||||
# a plugin_data or similar wrapper that closes over 'view_func'.
|
||||
if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars:
|
||||
try:
|
||||
idx = bottom.__code__.co_freevars.index("view_func")
|
||||
return bottom.__closure__[idx].cell_contents
|
||||
except (AttributeError, TypeError, IndexError):
|
||||
pass
|
||||
|
||||
return bottom
|
||||
|
||||
|
||||
class TestPluginInvokeLLMApi:
|
||||
"""Test PluginInvokeLLMApi endpoint structure"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeLLMWithStructuredOutputApi:
|
||||
"""Test PluginInvokeLLMWithStructuredOutputApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMWithStructuredOutputApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTextEmbeddingApi:
|
||||
"""Test PluginInvokeTextEmbeddingApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTextEmbeddingApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeRerankApi:
|
||||
"""Test PluginInvokeRerankApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeRerankApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTTSApi:
|
||||
"""Test PluginInvokeTTSApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTTSApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeSpeech2TextApi:
|
||||
"""Test PluginInvokeSpeech2TextApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSpeech2TextApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeModerationApi:
|
||||
"""Test PluginInvokeModerationApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeModerationApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeToolApi:
|
||||
"""Test PluginInvokeToolApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeToolApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeParameterExtractorNodeApi:
|
||||
"""Test PluginInvokeParameterExtractorNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeParameterExtractorNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeQuestionClassifierNodeApi:
|
||||
"""Test PluginInvokeQuestionClassifierNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeQuestionClassifierNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeAppApi:
|
||||
"""Test PluginInvokeAppApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeAppApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeEncryptApi:
|
||||
"""Test PluginInvokeEncryptApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeEncryptApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() delegates to PluginEncrypter and returns model_dump output"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act — extract raw post() bypassing all decorators including plugin_data
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload)
|
||||
assert result["data"] == {"encrypted": "data"}
|
||||
assert result.get("error") == ""
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() catches exceptions and returns error response"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed")
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
assert "encrypt failed" in result["error"]
|
||||
|
||||
|
||||
class TestPluginInvokeSummaryApi:
|
||||
"""Test PluginInvokeSummaryApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSummaryApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginUploadFileRequestApi:
|
||||
"""Test PluginUploadFileRequestApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginUploadFileRequestApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin")
|
||||
def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask):
|
||||
"""Test that post() generates a signed URL and returns it"""
|
||||
# Arrange
|
||||
mock_get_url.return_value = "https://storage.example.com/signed-upload-url"
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-id"
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.filename = "test.pdf"
|
||||
mock_payload.mimetype = "application/pdf"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginUploadFileRequestApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_get_url.assert_called_once_with(
|
||||
filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id"
|
||||
)
|
||||
assert result["data"]["url"] == "https://storage.example.com/signed-upload-url"
|
||||
|
||||
|
||||
class TestPluginFetchAppInfoApi:
|
||||
"""Test PluginFetchAppInfoApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginFetchAppInfoApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation")
|
||||
def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask):
|
||||
"""Test that post() fetches app info and returns it"""
|
||||
# Arrange
|
||||
mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.app_id = "app-123"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginFetchAppInfoApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id")
|
||||
assert result["data"] == {"app_name": "My App", "mode": "chat"}
|
||||
@ -0,0 +1,305 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.plugin.wraps import (
|
||||
TenantUserPayload,
|
||||
get_user,
|
||||
get_user_tenant,
|
||||
plugin_data,
|
||||
)
|
||||
|
||||
|
||||
class TestTenantUserPayload:
|
||||
"""Test TenantUserPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload passes validation"""
|
||||
data = {"tenant_id": "tenant123", "user_id": "user456"}
|
||||
payload = TenantUserPayload.model_validate(data)
|
||||
assert payload.tenant_id == "tenant123"
|
||||
assert payload.user_id == "user456"
|
||||
|
||||
def test_missing_tenant_id(self):
|
||||
"""Test missing tenant_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"user_id": "user456"})
|
||||
|
||||
def test_missing_user_id(self):
|
||||
"""Test missing user_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"tenant_id": "tenant123"})
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test get_user function"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test returning existing user when found by ID"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.query.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_anonymous_user_by_session_id(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test returning existing anonymous user by session_id"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.session_id = "anonymous_session"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "anonymous_session")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test creating new user when not found in database"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_new_user
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_use_default_session_id_when_user_id_none(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test using default session ID when user_id is None"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", None)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_raise_error_on_database_exception(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test raising ValueError when database operation fails"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with app.app_context():
|
||||
with pytest.raises(ValueError, match="user not found"):
|
||||
get_user("tenant123", "user123")
|
||||
|
||||
|
||||
class TestGetUserTenant:
|
||||
"""Test get_user_tenant decorator"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that decorator injects tenant_model and user_model into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user456"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
|
||||
def test_should_raise_error_when_tenant_id_missing(self, app: Flask):
|
||||
"""Test that Pydantic ValidationError is raised when tenant_id is missing from payload"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert - Pydantic validates payload before manual check
|
||||
with app.test_request_context(json={"user_id": "user456"}):
|
||||
with pytest.raises(ValidationError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_when_tenant_not_found(self, app: Flask):
|
||||
"""Test that ValueError is raised when tenant is not found"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="tenant not found"):
|
||||
protected_view()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that default session ID is used when user_id is empty string"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
|
||||
# Act - use empty string for user_id to trigger default logic
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
from models.model import DefaultEndUserSessionID
|
||||
|
||||
mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID)
|
||||
|
||||
|
||||
class PluginTestPayload:
|
||||
"""Simple test payload class"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.value = data.get("value")
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
return cls(data)
|
||||
|
||||
|
||||
class TestPluginData:
|
||||
"""Test plugin_data decorator"""
|
||||
|
||||
def test_should_inject_valid_payload(self, app: Flask):
|
||||
"""Test that valid payload is injected into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "test_data"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "test_data"
|
||||
|
||||
def test_should_raise_error_on_invalid_json(self, app: Flask):
|
||||
"""Test that ValueError is raised when JSON parsing fails"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert - Malformed JSON triggers ValueError
|
||||
with app.test_request_context(data="not valid json", content_type="application/json"):
|
||||
with pytest.raises(ValueError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_on_invalid_payload(self, app: Flask):
|
||||
"""Test that ValueError is raised when payload validation fails"""
|
||||
|
||||
# Arrange
|
||||
class InvalidPayload:
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
raise Exception("Validation failed")
|
||||
|
||||
@plugin_data(payload_type=InvalidPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"data": "test"}):
|
||||
with pytest.raises(ValueError, match="invalid payload"):
|
||||
protected_view()
|
||||
|
||||
def test_should_work_as_parameterized_decorator(self, app: Flask):
|
||||
"""Test that decorator works when used with parentheses"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "parameterized"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "parameterized"
|
||||
309
api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py
Normal file
309
api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""
|
||||
Unit tests for inner_api auth decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.inner_api.wraps import (
|
||||
billing_inner_api_only,
|
||||
enterprise_inner_api_only,
|
||||
enterprise_inner_api_user_auth,
|
||||
plugin_inner_api_only,
|
||||
)
|
||||
|
||||
|
||||
class TestBillingInnerApiOnly:
|
||||
"""Test billing_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiOnly:
|
||||
"""Test enterprise_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiUserAuth:
|
||||
"""Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication"""
|
||||
|
||||
def test_should_pass_through_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that request passes through when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_header_missing(self, app: Flask):
|
||||
"""Test that request passes through when Authorization header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_format_invalid(self, app: Flask):
|
||||
"""Test that request passes through when Authorization format is invalid (no colon)"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"Authorization": "invalid_format"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask):
|
||||
"""Test that request passes through when HMAC signature is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act - use wrong signature
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_inject_user_when_hmac_signature_valid(self, app: Flask):
|
||||
"""Test that user is injected when HMAC signature is valid"""
|
||||
# Arrange
|
||||
from base64 import b64encode
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user")
|
||||
|
||||
# Calculate valid HMAC signature
|
||||
user_id = "user123"
|
||||
inner_api_key = "valid_key"
|
||||
data_to_sign = f"DIFY {user_id}"
|
||||
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
|
||||
valid_signature = b64encode(signature.digest()).decode("utf-8")
|
||||
|
||||
# Create mock user
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
class TestPluginInnerApiOnly:
|
||||
"""Test plugin_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when PLUGIN_DAEMON_KEY is set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask):
|
||||
"""Test that 404 is returned when PLUGIN_DAEMON_KEY is not set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_404_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
206
api/tests/unit_tests/controllers/inner_api/test_mail.py
Normal file
206
api/tests/unit_tests/controllers/inner_api/test_mail.py
Normal file
@ -0,0 +1,206 @@
|
||||
"""
|
||||
Unit tests for inner_api mail module
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.mail import (
|
||||
BaseMail,
|
||||
BillingMail,
|
||||
EnterpriseMail,
|
||||
InnerMailPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestInnerMailPayload:
|
||||
"""Test InnerMailPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload_with_all_fields(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
"substitutions": {"key": "value"},
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions == {"key": "value"}
|
||||
|
||||
def test_valid_payload_without_substitutions(self):
|
||||
"""Test valid payload without optional substitutions"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions is None
|
||||
|
||||
def test_empty_to_list_fails_validation(self):
|
||||
"""Test that empty 'to' list fails validation due to min_length=1"""
|
||||
data = {
|
||||
"to": [],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_multiple_recipients_allowed(self):
|
||||
"""Test that multiple recipients are allowed"""
|
||||
data = {
|
||||
"to": ["user1@example.com", "user2@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert len(payload.to) == 2
|
||||
assert "user1@example.com" in payload.to
|
||||
assert "user2@example.com" in payload.to
|
||||
|
||||
def test_missing_to_field_fails_validation(self):
|
||||
"""Test that missing 'to' field fails validation"""
|
||||
data = {
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_subject_fails_validation(self):
|
||||
"""Test that missing 'subject' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_body_fails_validation(self):
|
||||
"""Test that missing 'body' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
|
||||
class TestBaseMail:
|
||||
"""Test BaseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BaseMail API instance"""
|
||||
return BaseMail()
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_sends_email_task(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends inner email task"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
json={
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
):
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Test Subject",
|
||||
body="Test Body",
|
||||
substitutions=None,
|
||||
)
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_with_substitutions(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends email with substitutions"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Hello {{name}}",
|
||||
"body": "Welcome {{name}}!",
|
||||
"substitutions": {"name": "John"},
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Hello {{name}}",
|
||||
body="Welcome {{name}}!",
|
||||
substitutions={"name": "John"},
|
||||
)
|
||||
|
||||
|
||||
class TestEnterpriseMail:
|
||||
"""Test EnterpriseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create EnterpriseMail API instance"""
|
||||
return EnterpriseMail()
|
||||
|
||||
def test_has_enterprise_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has enterprise_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
|
||||
assert enterprise_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
|
||||
|
||||
class TestBillingMail:
|
||||
"""Test BillingMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BillingMail API instance"""
|
||||
return BillingMail()
|
||||
|
||||
def test_has_billing_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that BillingMail has billing_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import billing_inner_api_only
|
||||
|
||||
assert billing_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that BillingMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
@ -0,0 +1,184 @@
|
||||
"""
|
||||
Unit tests for inner_api workspace module
|
||||
|
||||
Tests Pydantic model validation and endpoint handler logic.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them and focus on business logic.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.workspace.workspace import (
|
||||
EnterpriseWorkspace,
|
||||
EnterpriseWorkspaceNoOwnerEmail,
|
||||
WorkspaceCreatePayload,
|
||||
WorkspaceOwnerlessPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkspaceCreatePayload:
|
||||
"""Test WorkspaceCreatePayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"name": "My Workspace",
|
||||
"owner_email": "owner@example.com",
|
||||
}
|
||||
payload = WorkspaceCreatePayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
assert payload.owner_email == "owner@example.com"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {"owner_email": "owner@example.com"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
def test_missing_owner_email_fails_validation(self):
|
||||
"""Test that missing owner_email fails validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "owner_email" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestWorkspaceOwnerlessPayload:
|
||||
"""Test WorkspaceOwnerlessPayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with name passes validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
payload = WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspace:
|
||||
"""Test EnterpriseWorkspace API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspace()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that EnterpriseWorkspace has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace and assigns the owner account"""
|
||||
# Arrange
|
||||
mock_account = MagicMock()
|
||||
mock_account.email = "owner@example.com"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["name"] == "My Workspace"
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
|
||||
"""Test that post() returns 404 when the owner account does not exist"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "owner account not found."}, 404)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspaceNoOwnerEmail:
|
||||
"""Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspaceNoOwnerEmail()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace without an owner and returns expected fields"""
|
||||
# Arrange
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.encrypt_public_key = "pub-key"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.custom_config = None
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["encrypt_public_key"] == "pub-key"
|
||||
assert result["tenant"]["custom_config"] == {}
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
Loading…
Reference in New Issue
Block a user