test: improve unit tests for controllers.inner_api (#32203)

This commit is contained in:
Dev Sharma 2026-03-12 08:37:56 +05:30 committed by GitHub
parent 31eba65fe0
commit 36c1f4d506
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1318 additions and 0 deletions

View File

@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()

View File

@ -0,0 +1,313 @@
"""
Unit tests for inner_api plugin endpoints
Tests endpoint structure (method existence) for all plugin APIs, plus
handler-level logic tests for representative non-streaming endpoints.
Auth/setup decorators are tested separately in test_auth_wraps.py;
handler tests use inspect.unwrap() to bypass them.
"""
import inspect
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.inner_api.plugin.plugin import (
PluginFetchAppInfoApi,
PluginInvokeAppApi,
PluginInvokeEncryptApi,
PluginInvokeLLMApi,
PluginInvokeLLMWithStructuredOutputApi,
PluginInvokeModerationApi,
PluginInvokeParameterExtractorNodeApi,
PluginInvokeQuestionClassifierNodeApi,
PluginInvokeRerankApi,
PluginInvokeSpeech2TextApi,
PluginInvokeSummaryApi,
PluginInvokeTextEmbeddingApi,
PluginInvokeToolApi,
PluginInvokeTTSApi,
PluginUploadFileRequestApi,
)
def _extract_raw_post(cls):
"""Extract the raw post() method from a plugin endpoint class.
Plugin endpoint methods are wrapped by several decorators (get_user_tenant,
setup_required, plugin_inner_api_only, plugin_data). These decorators
use @wraps where possible. This helper ensures we retrieve the original
post(self, user_model, tenant_model, payload) function by unwrapping
and, if necessary, walking the closure of the innermost wrapper.
"""
bottom = inspect.unwrap(cls.post)
# If unwrap() didn't get us to the raw function (e.g. if a decorator
# missed @wraps), try to extract it from the closure if it looks like
# a plugin_data or similar wrapper that closes over 'view_func'.
if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars:
try:
idx = bottom.__code__.co_freevars.index("view_func")
return bottom.__closure__[idx].cell_contents
except (AttributeError, TypeError, IndexError):
pass
return bottom
class TestPluginInvokeLLMApi:
"""Test PluginInvokeLLMApi endpoint structure"""
@pytest.fixture
def api_instance(self):
return PluginInvokeLLMApi()
def test_has_post_method(self, api_instance):
"""Test that endpoint has post method"""
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeLLMWithStructuredOutputApi:
"""Test PluginInvokeLLMWithStructuredOutputApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeLLMWithStructuredOutputApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeTextEmbeddingApi:
"""Test PluginInvokeTextEmbeddingApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeTextEmbeddingApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeRerankApi:
"""Test PluginInvokeRerankApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeRerankApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeTTSApi:
"""Test PluginInvokeTTSApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeTTSApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeSpeech2TextApi:
"""Test PluginInvokeSpeech2TextApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeSpeech2TextApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeModerationApi:
"""Test PluginInvokeModerationApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeModerationApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeToolApi:
"""Test PluginInvokeToolApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeToolApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeParameterExtractorNodeApi:
"""Test PluginInvokeParameterExtractorNodeApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeParameterExtractorNodeApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeQuestionClassifierNodeApi:
"""Test PluginInvokeQuestionClassifierNodeApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeQuestionClassifierNodeApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeAppApi:
"""Test PluginInvokeAppApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeAppApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginInvokeEncryptApi:
"""Test PluginInvokeEncryptApi endpoint structure and handler logic"""
@pytest.fixture
def api_instance(self):
return PluginInvokeEncryptApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask):
"""Test that post() delegates to PluginEncrypter and returns model_dump output"""
# Arrange
mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"}
mock_tenant = MagicMock()
mock_user = MagicMock()
mock_payload = MagicMock()
# Act — extract raw post() bypassing all decorators including plugin_data
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload)
assert result["data"] == {"encrypted": "data"}
assert result.get("error") == ""
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask):
"""Test that post() catches exceptions and returns error response"""
# Arrange
mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed")
mock_tenant = MagicMock()
mock_user = MagicMock()
mock_payload = MagicMock()
# Act
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
assert "encrypt failed" in result["error"]
class TestPluginInvokeSummaryApi:
"""Test PluginInvokeSummaryApi endpoint"""
@pytest.fixture
def api_instance(self):
return PluginInvokeSummaryApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
class TestPluginUploadFileRequestApi:
"""Test PluginUploadFileRequestApi endpoint structure and handler logic"""
@pytest.fixture
def api_instance(self):
return PluginUploadFileRequestApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin")
def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask):
"""Test that post() generates a signed URL and returns it"""
# Arrange
mock_get_url.return_value = "https://storage.example.com/signed-upload-url"
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_user = MagicMock()
mock_user.id = "user-id"
mock_payload = MagicMock()
mock_payload.filename = "test.pdf"
mock_payload.mimetype = "application/pdf"
# Act
raw_post = _extract_raw_post(PluginUploadFileRequestApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
mock_get_url.assert_called_once_with(
filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id"
)
assert result["data"]["url"] == "https://storage.example.com/signed-upload-url"
class TestPluginFetchAppInfoApi:
"""Test PluginFetchAppInfoApi endpoint structure and handler logic"""
@pytest.fixture
def api_instance(self):
return PluginFetchAppInfoApi()
def test_has_post_method(self, api_instance):
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation")
def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask):
"""Test that post() fetches app info and returns it"""
# Arrange
mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"}
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_user = MagicMock()
mock_payload = MagicMock()
mock_payload.app_id = "app-123"
# Act
raw_post = _extract_raw_post(PluginFetchAppInfoApi)
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
# Assert
mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id")
assert result["data"] == {"app_name": "My App", "mode": "chat"}

View File

@ -0,0 +1,305 @@
"""
Unit tests for inner_api plugin decorators
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.inner_api.plugin.wraps import (
TenantUserPayload,
get_user,
get_user_tenant,
plugin_data,
)
class TestTenantUserPayload:
"""Test TenantUserPayload Pydantic model"""
def test_valid_payload(self):
"""Test valid payload passes validation"""
data = {"tenant_id": "tenant123", "user_id": "user456"}
payload = TenantUserPayload.model_validate(data)
assert payload.tenant_id == "tenant123"
assert payload.user_id == "user456"
def test_missing_tenant_id(self):
"""Test missing tenant_id raises ValidationError"""
with pytest.raises(ValidationError):
TenantUserPayload.model_validate({"user_id": "user456"})
def test_missing_user_id(self):
"""Test missing user_id raises ValidationError"""
with pytest.raises(ValidationError):
TenantUserPayload.model_validate({"tenant_id": "tenant123"})
class TestGetUser:
"""Test get_user function"""
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
"""Test returning existing user when found by ID"""
# Arrange
mock_user = MagicMock()
mock_user.id = "user123"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# Act
with app.app_context():
result = get_user("tenant123", "user123")
# Assert
assert result == mock_user
mock_session.query.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_anonymous_user_by_session_id(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
):
"""Test returning existing anonymous user by session_id"""
# Arrange
mock_user = MagicMock()
mock_user.session_id = "anonymous_session"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# Act
with app.app_context():
result = get_user("tenant123", "anonymous_session")
# Assert
assert result == mock_user
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
"""Test creating new user when not found in database"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = None
mock_new_user = MagicMock()
mock_enduser_class.return_value = mock_new_user
# Act
with app.app_context():
result = get_user("tenant123", "user123")
# Assert
assert result == mock_new_user
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_use_default_session_id_when_user_id_none(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
):
"""Test using default session ID when user_id is None"""
# Arrange
mock_user = MagicMock()
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.where.return_value.first.return_value = mock_user
# Act
with app.app_context():
result = get_user("tenant123", None)
# Assert
assert result == mock_user
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.Session")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_raise_error_on_database_exception(
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
):
"""Test raising ValueError when database operation fails"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.query.side_effect = Exception("Database error")
# Act & Assert
with app.app_context():
with pytest.raises(ValueError, match="user not found"):
get_user("tenant123", "user123")
class TestGetUserTenant:
"""Test get_user_tenant decorator"""
@patch("controllers.inner_api.plugin.wraps.Tenant")
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
"""Test that decorator injects tenant_model and user_model into kwargs"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return {"tenant": tenant_model, "user": user_model}
mock_tenant = MagicMock()
mock_tenant.id = "tenant123"
mock_user = MagicMock()
mock_user.id = "user456"
# Act
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()
# Assert
assert result["tenant"] == mock_tenant
assert result["user"] == mock_user
def test_should_raise_error_when_tenant_id_missing(self, app: Flask):
"""Test that Pydantic ValidationError is raised when tenant_id is missing from payload"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return "success"
# Act & Assert - Pydantic validates payload before manual check
with app.test_request_context(json={"user_id": "user456"}):
with pytest.raises(ValidationError):
protected_view()
def test_should_raise_error_when_tenant_not_found(self, app: Flask):
"""Test that ValueError is raised when tenant is not found"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return "success"
# Act & Assert
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="tenant not found"):
protected_view()
@patch("controllers.inner_api.plugin.wraps.Tenant")
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
"""Test that default session ID is used when user_id is empty string"""
# Arrange
@get_user_tenant
def protected_view(tenant_model, user_model, **kwargs):
return {"tenant": tenant_model, "user": user_model}
mock_tenant = MagicMock()
mock_tenant.id = "tenant123"
mock_user = MagicMock()
# Act - use empty string for user_id to trigger default logic
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_get_user.return_value = mock_user
result = protected_view()
# Assert
assert result["tenant"] == mock_tenant
assert result["user"] == mock_user
from models.model import DefaultEndUserSessionID
mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID)
class PluginTestPayload:
"""Simple test payload class"""
def __init__(self, data: dict):
self.value = data.get("value")
@classmethod
def model_validate(cls, data: dict):
return cls(data)
class TestPluginData:
"""Test plugin_data decorator"""
def test_should_inject_valid_payload(self, app: Flask):
"""Test that valid payload is injected into kwargs"""
# Arrange
@plugin_data(payload_type=PluginTestPayload)
def protected_view(payload, **kwargs):
return payload
# Act
with app.test_request_context(json={"value": "test_data"}):
result = protected_view()
# Assert
assert result.value == "test_data"
def test_should_raise_error_on_invalid_json(self, app: Flask):
"""Test that ValueError is raised when JSON parsing fails"""
# Arrange
@plugin_data(payload_type=PluginTestPayload)
def protected_view(payload, **kwargs):
return payload
# Act & Assert - Malformed JSON triggers ValueError
with app.test_request_context(data="not valid json", content_type="application/json"):
with pytest.raises(ValueError):
protected_view()
def test_should_raise_error_on_invalid_payload(self, app: Flask):
"""Test that ValueError is raised when payload validation fails"""
# Arrange
class InvalidPayload:
@classmethod
def model_validate(cls, data: dict):
raise Exception("Validation failed")
@plugin_data(payload_type=InvalidPayload)
def protected_view(payload, **kwargs):
return payload
# Act & Assert
with app.test_request_context(json={"data": "test"}):
with pytest.raises(ValueError, match="invalid payload"):
protected_view()
def test_should_work_as_parameterized_decorator(self, app: Flask):
"""Test that decorator works when used with parentheses"""
# Arrange
@plugin_data(payload_type=PluginTestPayload)
def protected_view(payload, **kwargs):
return payload
# Act
with app.test_request_context(json={"value": "parameterized"}):
result = protected_view()
# Assert
assert result.value == "parameterized"

View File

@ -0,0 +1,309 @@
"""
Unit tests for inner_api auth decorators
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException
from configs import dify_config
from controllers.inner_api.wraps import (
billing_inner_api_only,
enterprise_inner_api_only,
enterprise_inner_api_user_auth,
plugin_inner_api_only,
)
class TestBillingInnerApiOnly:
"""Test billing_inner_api_only decorator"""
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
"""Test that valid API key allows access when INNER_API is enabled"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
result = protected_view()
# Assert
assert result == "success"
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
"""Test that 404 is returned when INNER_API is disabled"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context():
with patch.object(dify_config, "INNER_API", False):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404
def test_should_return_401_when_api_key_missing(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
def test_should_return_401_when_api_key_invalid(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
# Arrange
@billing_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
class TestEnterpriseInnerApiOnly:
"""Test enterprise_inner_api_only decorator"""
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
"""Test that valid API key allows access when INNER_API is enabled"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
result = protected_view()
# Assert
assert result == "success"
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
"""Test that 404 is returned when INNER_API is disabled"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context():
with patch.object(dify_config, "INNER_API", False):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404
def test_should_return_401_when_api_key_missing(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
def test_should_return_401_when_api_key_invalid(self, app: Flask):
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
# Arrange
@enterprise_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
with patch.object(dify_config, "INNER_API", True):
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 401
class TestEnterpriseInnerApiUserAuth:
"""Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication"""
def test_should_pass_through_when_inner_api_disabled(self, app: Flask):
"""Test that request passes through when INNER_API is disabled"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act
with app.test_request_context():
with patch.object(dify_config, "INNER_API", False):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_pass_through_when_authorization_header_missing(self, app: Flask):
"""Test that request passes through when Authorization header is missing"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act
with app.test_request_context(headers={}):
with patch.object(dify_config, "INNER_API", True):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_pass_through_when_authorization_format_invalid(self, app: Flask):
"""Test that request passes through when Authorization format is invalid (no colon)"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act
with app.test_request_context(headers={"Authorization": "invalid_format"}):
with patch.object(dify_config, "INNER_API", True):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask):
"""Test that request passes through when HMAC signature is invalid"""
# Arrange
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user", "no_user")
# Act - use wrong signature
with app.test_request_context(
headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"}
):
with patch.object(dify_config, "INNER_API", True):
result = protected_view()
# Assert
assert result == "no_user"
def test_should_inject_user_when_hmac_signature_valid(self, app: Flask):
"""Test that user is injected when HMAC signature is valid"""
# Arrange
from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new
@enterprise_inner_api_user_auth
def protected_view(**kwargs):
return kwargs.get("user")
# Calculate valid HMAC signature
user_id = "user123"
inner_api_key = "valid_key"
data_to_sign = f"DIFY {user_id}"
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
valid_signature = b64encode(signature.digest()).decode("utf-8")
# Create mock user
mock_user = MagicMock()
mock_user.id = user_id
# Act
with app.test_request_context(
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
):
with patch.object(dify_config, "INNER_API", True):
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = mock_user
result = protected_view()
# Assert
assert result == mock_user
class TestPluginInnerApiOnly:
"""Test plugin_inner_api_only decorator"""
def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask):
"""Test that valid API key allows access when PLUGIN_DAEMON_KEY is set"""
# Arrange
@plugin_inner_api_only
def protected_view():
return "success"
# Act
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}):
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
result = protected_view()
# Assert
assert result == "success"
def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask):
"""Test that 404 is returned when PLUGIN_DAEMON_KEY is not set"""
# Arrange
@plugin_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context():
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404
def test_should_return_404_when_api_key_invalid(self, app: Flask):
"""Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)"""
# Arrange
@plugin_inner_api_only
def protected_view():
return "success"
# Act & Assert
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
with pytest.raises(HTTPException) as exc_info:
protected_view()
assert exc_info.value.code == 404

View File

@ -0,0 +1,206 @@
"""
Unit tests for inner_api mail module
"""
from unittest.mock import patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.inner_api.mail import (
BaseMail,
BillingMail,
EnterpriseMail,
InnerMailPayload,
)
class TestInnerMailPayload:
"""Test InnerMailPayload Pydantic model"""
def test_valid_payload_with_all_fields(self):
"""Test valid payload with all fields passes validation"""
data = {
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
"substitutions": {"key": "value"},
}
payload = InnerMailPayload.model_validate(data)
assert payload.to == ["test@example.com"]
assert payload.subject == "Test Subject"
assert payload.body == "Test Body"
assert payload.substitutions == {"key": "value"}
def test_valid_payload_without_substitutions(self):
"""Test valid payload without optional substitutions"""
data = {
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
payload = InnerMailPayload.model_validate(data)
assert payload.to == ["test@example.com"]
assert payload.subject == "Test Subject"
assert payload.body == "Test Body"
assert payload.substitutions is None
def test_empty_to_list_fails_validation(self):
"""Test that empty 'to' list fails validation due to min_length=1"""
data = {
"to": [],
"subject": "Test Subject",
"body": "Test Body",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
def test_multiple_recipients_allowed(self):
"""Test that multiple recipients are allowed"""
data = {
"to": ["user1@example.com", "user2@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
payload = InnerMailPayload.model_validate(data)
assert len(payload.to) == 2
assert "user1@example.com" in payload.to
assert "user2@example.com" in payload.to
def test_missing_to_field_fails_validation(self):
"""Test that missing 'to' field fails validation"""
data = {
"subject": "Test Subject",
"body": "Test Body",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
def test_missing_subject_fails_validation(self):
"""Test that missing 'subject' field fails validation"""
data = {
"to": ["test@example.com"],
"body": "Test Body",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
def test_missing_body_fails_validation(self):
"""Test that missing 'body' field fails validation"""
data = {
"to": ["test@example.com"],
"subject": "Test Subject",
}
with pytest.raises(ValidationError):
InnerMailPayload.model_validate(data)
class TestBaseMail:
"""Test BaseMail API endpoint"""
@pytest.fixture
def api_instance(self):
"""Create BaseMail API instance"""
return BaseMail()
@patch("controllers.inner_api.mail.send_inner_email_task")
def test_post_sends_email_task(self, mock_task, api_instance, app: Flask):
"""Test that POST sends inner email task"""
# Arrange
mock_task.delay.return_value = None
# Act
with app.test_request_context(
json={
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
):
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
mock_ns.payload = {
"to": ["test@example.com"],
"subject": "Test Subject",
"body": "Test Body",
}
result = api_instance.post()
# Assert
assert result == ({"message": "success"}, 200)
mock_task.delay.assert_called_once_with(
to=["test@example.com"],
subject="Test Subject",
body="Test Body",
substitutions=None,
)
@patch("controllers.inner_api.mail.send_inner_email_task")
def test_post_with_substitutions(self, mock_task, api_instance, app: Flask):
"""Test that POST sends email with substitutions"""
# Arrange
mock_task.delay.return_value = None
# Act
with app.test_request_context():
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
mock_ns.payload = {
"to": ["test@example.com"],
"subject": "Hello {{name}}",
"body": "Welcome {{name}}!",
"substitutions": {"name": "John"},
}
result = api_instance.post()
# Assert
assert result == ({"message": "success"}, 200)
mock_task.delay.assert_called_once_with(
to=["test@example.com"],
subject="Hello {{name}}",
body="Welcome {{name}}!",
substitutions={"name": "John"},
)
class TestEnterpriseMail:
"""Test EnterpriseMail API endpoint"""
@pytest.fixture
def api_instance(self):
"""Create EnterpriseMail API instance"""
return EnterpriseMail()
def test_has_enterprise_inner_api_only_decorator(self, api_instance):
"""Test that EnterpriseMail has enterprise_inner_api_only decorator"""
# Check method_decorators
from controllers.inner_api.wraps import enterprise_inner_api_only
assert enterprise_inner_api_only in api_instance.method_decorators
def test_has_setup_required_decorator(self, api_instance):
"""Test that EnterpriseMail has setup_required decorator"""
# Check by decorator name instead of object reference
decorator_names = [d.__name__ for d in api_instance.method_decorators]
assert "setup_required" in decorator_names
class TestBillingMail:
"""Test BillingMail API endpoint"""
@pytest.fixture
def api_instance(self):
"""Create BillingMail API instance"""
return BillingMail()
def test_has_billing_inner_api_only_decorator(self, api_instance):
"""Test that BillingMail has billing_inner_api_only decorator"""
# Check method_decorators
from controllers.inner_api.wraps import billing_inner_api_only
assert billing_inner_api_only in api_instance.method_decorators
def test_has_setup_required_decorator(self, api_instance):
"""Test that BillingMail has setup_required decorator"""
# Check by decorator name instead of object reference
decorator_names = [d.__name__ for d in api_instance.method_decorators]
assert "setup_required" in decorator_names

View File

@ -0,0 +1,184 @@
"""
Unit tests for inner_api workspace module
Tests Pydantic model validation and endpoint handler logic.
Auth/setup decorators are tested separately in test_auth_wraps.py;
handler tests use inspect.unwrap() to bypass them and focus on business logic.
"""
import inspect
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.inner_api.workspace.workspace import (
EnterpriseWorkspace,
EnterpriseWorkspaceNoOwnerEmail,
WorkspaceCreatePayload,
WorkspaceOwnerlessPayload,
)
class TestWorkspaceCreatePayload:
"""Test WorkspaceCreatePayload Pydantic model validation"""
def test_valid_payload(self):
"""Test valid payload with all fields passes validation"""
data = {
"name": "My Workspace",
"owner_email": "owner@example.com",
}
payload = WorkspaceCreatePayload.model_validate(data)
assert payload.name == "My Workspace"
assert payload.owner_email == "owner@example.com"
def test_missing_name_fails_validation(self):
"""Test that missing name fails validation"""
data = {"owner_email": "owner@example.com"}
with pytest.raises(ValidationError) as exc_info:
WorkspaceCreatePayload.model_validate(data)
assert "name" in str(exc_info.value)
def test_missing_owner_email_fails_validation(self):
"""Test that missing owner_email fails validation"""
data = {"name": "My Workspace"}
with pytest.raises(ValidationError) as exc_info:
WorkspaceCreatePayload.model_validate(data)
assert "owner_email" in str(exc_info.value)
class TestWorkspaceOwnerlessPayload:
"""Test WorkspaceOwnerlessPayload Pydantic model validation"""
def test_valid_payload(self):
"""Test valid payload with name passes validation"""
data = {"name": "My Workspace"}
payload = WorkspaceOwnerlessPayload.model_validate(data)
assert payload.name == "My Workspace"
def test_missing_name_fails_validation(self):
"""Test that missing name fails validation"""
data = {}
with pytest.raises(ValidationError) as exc_info:
WorkspaceOwnerlessPayload.model_validate(data)
assert "name" in str(exc_info.value)
class TestEnterpriseWorkspace:
"""Test EnterpriseWorkspace API endpoint handler logic.
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
and exercise the core business logic directly.
"""
@pytest.fixture
def api_instance(self):
return EnterpriseWorkspace()
def test_has_post_method(self, api_instance):
"""Test that EnterpriseWorkspace has post method"""
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
@patch("controllers.inner_api.workspace.workspace.TenantService")
@patch("controllers.inner_api.workspace.workspace.db")
def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask):
"""Test that post() creates a workspace and assigns the owner account"""
# Arrange
mock_account = MagicMock()
mock_account.email = "owner@example.com"
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
now = datetime(2025, 1, 1, 12, 0, 0)
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_tenant.name = "My Workspace"
mock_tenant.plan = "sandbox"
mock_tenant.status = "normal"
mock_tenant.created_at = now
mock_tenant.updated_at = now
mock_tenant_svc.create_tenant.return_value = mock_tenant
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
unwrapped_post = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"}
result = unwrapped_post(api_instance)
# Assert
assert result["message"] == "enterprise workspace created."
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["name"] == "My Workspace"
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
mock_event.send.assert_called_once_with(mock_tenant)
@patch("controllers.inner_api.workspace.workspace.db")
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
"""Test that post() returns 404 when the owner account does not exist"""
# Arrange
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
# Act
unwrapped_post = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"}
result = unwrapped_post(api_instance)
# Assert
assert result == ({"message": "owner account not found."}, 404)
class TestEnterpriseWorkspaceNoOwnerEmail:
"""Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic.
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
and exercise the core business logic directly.
"""
@pytest.fixture
def api_instance(self):
return EnterpriseWorkspaceNoOwnerEmail()
def test_has_post_method(self, api_instance):
"""Test that endpoint has post method"""
assert hasattr(api_instance, "post")
assert callable(api_instance.post)
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
@patch("controllers.inner_api.workspace.workspace.TenantService")
def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask):
"""Test that post() creates a workspace without an owner and returns expected fields"""
# Arrange
now = datetime(2025, 1, 1, 12, 0, 0)
mock_tenant = MagicMock()
mock_tenant.id = "tenant-id"
mock_tenant.name = "My Workspace"
mock_tenant.encrypt_public_key = "pub-key"
mock_tenant.plan = "sandbox"
mock_tenant.status = "normal"
mock_tenant.custom_config = None
mock_tenant.created_at = now
mock_tenant.updated_at = now
mock_tenant_svc.create_tenant.return_value = mock_tenant
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
unwrapped_post = inspect.unwrap(api_instance.post)
with app.test_request_context():
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
mock_ns.payload = {"name": "My Workspace"}
result = unwrapped_post(api_instance)
# Assert
assert result["message"] == "enterprise workspace created."
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["encrypt_public_key"] == "pub-key"
assert result["tenant"]["custom_config"] == {}
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_event.send.assert_called_once_with(mock_tenant)