diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index edf3ac393c..766d95b3dd 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]): def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): def decorator(view_func: Callable[P, R]): + @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): try: data = request.get_json() diff --git a/api/tests/unit_tests/controllers/inner_api/__init__.py b/api/tests/unit_tests/controllers/inner_api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/__init__.py b/api/tests/unit_tests/controllers/inner_api/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py new file mode 100644 index 0000000000..844f04fe72 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py @@ -0,0 +1,313 @@ +""" +Unit tests for inner_api plugin endpoints + +Tests endpoint structure (method existence) for all plugin APIs, plus +handler-level logic tests for representative non-streaming endpoints. +Auth/setup decorators are tested separately in test_auth_wraps.py; +handler tests use inspect.unwrap() to bypass them. +""" + +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.inner_api.plugin.plugin import ( + PluginFetchAppInfoApi, + PluginInvokeAppApi, + PluginInvokeEncryptApi, + PluginInvokeLLMApi, + PluginInvokeLLMWithStructuredOutputApi, + PluginInvokeModerationApi, + PluginInvokeParameterExtractorNodeApi, + PluginInvokeQuestionClassifierNodeApi, + PluginInvokeRerankApi, + PluginInvokeSpeech2TextApi, + PluginInvokeSummaryApi, + PluginInvokeTextEmbeddingApi, + PluginInvokeToolApi, + PluginInvokeTTSApi, + PluginUploadFileRequestApi, +) + + +def _extract_raw_post(cls): + """Extract the raw post() method from a plugin endpoint class. + + Plugin endpoint methods are wrapped by several decorators (get_user_tenant, + setup_required, plugin_inner_api_only, plugin_data). These decorators + use @wraps where possible. This helper ensures we retrieve the original + post(self, user_model, tenant_model, payload) function by unwrapping + and, if necessary, walking the closure of the innermost wrapper. + """ + bottom = inspect.unwrap(cls.post) + + # If unwrap() didn't get us to the raw function (e.g. if a decorator + # missed @wraps), try to extract it from the closure if it looks like + # a plugin_data or similar wrapper that closes over 'view_func'. + if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars: + try: + idx = bottom.__code__.co_freevars.index("view_func") + return bottom.__closure__[idx].cell_contents + except (AttributeError, TypeError, IndexError): + pass + + return bottom + + +class TestPluginInvokeLLMApi: + """Test PluginInvokeLLMApi endpoint structure""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeLLMApi() + + def test_has_post_method(self, api_instance): + """Test that endpoint has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeLLMWithStructuredOutputApi: + """Test PluginInvokeLLMWithStructuredOutputApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeLLMWithStructuredOutputApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeTextEmbeddingApi: + """Test PluginInvokeTextEmbeddingApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeTextEmbeddingApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeRerankApi: + """Test PluginInvokeRerankApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeRerankApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeTTSApi: + """Test PluginInvokeTTSApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeTTSApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeSpeech2TextApi: + """Test PluginInvokeSpeech2TextApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeSpeech2TextApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeModerationApi: + """Test PluginInvokeModerationApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeModerationApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeToolApi: + """Test PluginInvokeToolApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeToolApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeParameterExtractorNodeApi: + """Test PluginInvokeParameterExtractorNodeApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeParameterExtractorNodeApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeQuestionClassifierNodeApi: + """Test PluginInvokeQuestionClassifierNodeApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeQuestionClassifierNodeApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeAppApi: + """Test PluginInvokeAppApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeAppApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeEncryptApi: + """Test PluginInvokeEncryptApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeEncryptApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.PluginEncrypter") + def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask): + """Test that post() delegates to PluginEncrypter and returns model_dump output""" + # Arrange + mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"} + mock_tenant = MagicMock() + mock_user = MagicMock() + mock_payload = MagicMock() + + # Act — extract raw post() bypassing all decorators including plugin_data + raw_post = _extract_raw_post(PluginInvokeEncryptApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload) + assert result["data"] == {"encrypted": "data"} + assert result.get("error") == "" + + @patch("controllers.inner_api.plugin.plugin.PluginEncrypter") + def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask): + """Test that post() catches exceptions and returns error response""" + # Arrange + mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed") + mock_tenant = MagicMock() + mock_user = MagicMock() + mock_payload = MagicMock() + + # Act + raw_post = _extract_raw_post(PluginInvokeEncryptApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + assert "encrypt failed" in result["error"] + + +class TestPluginInvokeSummaryApi: + """Test PluginInvokeSummaryApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeSummaryApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginUploadFileRequestApi: + """Test PluginUploadFileRequestApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginUploadFileRequestApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin") + def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask): + """Test that post() generates a signed URL and returns it""" + # Arrange + mock_get_url.return_value = "https://storage.example.com/signed-upload-url" + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_user = MagicMock() + mock_user.id = "user-id" + mock_payload = MagicMock() + mock_payload.filename = "test.pdf" + mock_payload.mimetype = "application/pdf" + + # Act + raw_post = _extract_raw_post(PluginUploadFileRequestApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_get_url.assert_called_once_with( + filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id" + ) + assert result["data"]["url"] == "https://storage.example.com/signed-upload-url" + + +class TestPluginFetchAppInfoApi: + """Test PluginFetchAppInfoApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginFetchAppInfoApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation") + def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask): + """Test that post() fetches app info and returns it""" + # Arrange + mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"} + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_user = MagicMock() + mock_payload = MagicMock() + mock_payload.app_id = "app-123" + + # Act + raw_post = _extract_raw_post(PluginFetchAppInfoApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id") + assert result["data"] == {"app_name": "My App", "mode": "chat"} diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py new file mode 100644 index 0000000000..6de07a23e5 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -0,0 +1,305 @@ +""" +Unit tests for inner_api plugin decorators +""" + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.plugin.wraps import ( + TenantUserPayload, + get_user, + get_user_tenant, + plugin_data, +) + + +class TestTenantUserPayload: + """Test TenantUserPayload Pydantic model""" + + def test_valid_payload(self): + """Test valid payload passes validation""" + data = {"tenant_id": "tenant123", "user_id": "user456"} + payload = TenantUserPayload.model_validate(data) + assert payload.tenant_id == "tenant123" + assert payload.user_id == "user456" + + def test_missing_tenant_id(self): + """Test missing tenant_id raises ValidationError""" + with pytest.raises(ValidationError): + TenantUserPayload.model_validate({"user_id": "user456"}) + + def test_missing_user_id(self): + """Test missing user_id raises ValidationError""" + with pytest.raises(ValidationError): + TenantUserPayload.model_validate({"tenant_id": "tenant123"}) + + +class TestGetUser: + """Test get_user function""" + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + """Test returning existing user when found by ID""" + # Arrange + mock_user = MagicMock() + mock_user.id = "user123" + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", "user123") + + # Assert + assert result == mock_user + mock_session.query.assert_called_once() + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_return_existing_anonymous_user_by_session_id( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test returning existing anonymous user by session_id""" + # Arrange + mock_user = MagicMock() + mock_user.session_id = "anonymous_session" + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", "anonymous_session") + + # Assert + assert result == mock_user + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + """Test creating new user when not found in database""" + # Arrange + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = None + mock_new_user = MagicMock() + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant123", "user123") + + # Assert + assert result == mock_new_user + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_use_default_session_id_when_user_id_none( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test using default session ID when user_id is None""" + # Arrange + mock_user = MagicMock() + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", None) + + # Assert + assert result == mock_user + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_raise_error_on_database_exception( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test raising ValueError when database operation fails""" + # Arrange + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.side_effect = Exception("Database error") + + # Act & Assert + with app.app_context(): + with pytest.raises(ValueError, match="user not found"): + get_user("tenant123", "user123") + + +class TestGetUserTenant: + """Test get_user_tenant decorator""" + + @patch("controllers.inner_api.plugin.wraps.Tenant") + def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch): + """Test that decorator injects tenant_model and user_model into kwargs""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return {"tenant": tenant_model, "user": user_model} + + mock_tenant = MagicMock() + mock_tenant.id = "tenant123" + mock_user = MagicMock() + mock_user.id = "user456" + + # Act + with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}): + monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) + with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: + mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get_user.return_value = mock_user + result = protected_view() + + # Assert + assert result["tenant"] == mock_tenant + assert result["user"] == mock_user + + def test_should_raise_error_when_tenant_id_missing(self, app: Flask): + """Test that Pydantic ValidationError is raised when tenant_id is missing from payload""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return "success" + + # Act & Assert - Pydantic validates payload before manual check + with app.test_request_context(json={"user_id": "user456"}): + with pytest.raises(ValidationError): + protected_view() + + def test_should_raise_error_when_tenant_not_found(self, app: Flask): + """Test that ValueError is raised when tenant is not found""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return "success" + + # Act & Assert + with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}): + with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + mock_query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="tenant not found"): + protected_view() + + @patch("controllers.inner_api.plugin.wraps.Tenant") + def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch): + """Test that default session ID is used when user_id is empty string""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return {"tenant": tenant_model, "user": user_model} + + mock_tenant = MagicMock() + mock_tenant.id = "tenant123" + mock_user = MagicMock() + + # Act - use empty string for user_id to trigger default logic + with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}): + monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) + with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: + mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get_user.return_value = mock_user + result = protected_view() + + # Assert + assert result["tenant"] == mock_tenant + assert result["user"] == mock_user + from models.model import DefaultEndUserSessionID + + mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID) + + +class PluginTestPayload: + """Simple test payload class""" + + def __init__(self, data: dict): + self.value = data.get("value") + + @classmethod + def model_validate(cls, data: dict): + return cls(data) + + +class TestPluginData: + """Test plugin_data decorator""" + + def test_should_inject_valid_payload(self, app: Flask): + """Test that valid payload is injected into kwargs""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act + with app.test_request_context(json={"value": "test_data"}): + result = protected_view() + + # Assert + assert result.value == "test_data" + + def test_should_raise_error_on_invalid_json(self, app: Flask): + """Test that ValueError is raised when JSON parsing fails""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act & Assert - Malformed JSON triggers ValueError + with app.test_request_context(data="not valid json", content_type="application/json"): + with pytest.raises(ValueError): + protected_view() + + def test_should_raise_error_on_invalid_payload(self, app: Flask): + """Test that ValueError is raised when payload validation fails""" + + # Arrange + class InvalidPayload: + @classmethod + def model_validate(cls, data: dict): + raise Exception("Validation failed") + + @plugin_data(payload_type=InvalidPayload) + def protected_view(payload, **kwargs): + return payload + + # Act & Assert + with app.test_request_context(json={"data": "test"}): + with pytest.raises(ValueError, match="invalid payload"): + protected_view() + + def test_should_work_as_parameterized_decorator(self, app: Flask): + """Test that decorator works when used with parentheses""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act + with app.test_request_context(json={"value": "parameterized"}): + result = protected_view() + + # Assert + assert result.value == "parameterized" diff --git a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py new file mode 100644 index 0000000000..883ccdea2c --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py @@ -0,0 +1,309 @@ +""" +Unit tests for inner_api auth decorators +""" + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import HTTPException + +from configs import dify_config +from controllers.inner_api.wraps import ( + billing_inner_api_only, + enterprise_inner_api_only, + enterprise_inner_api_user_auth, + plugin_inner_api_only, +) + + +class TestBillingInnerApiOnly: + """Test billing_inner_api_only decorator""" + + def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask): + """Test that valid API key allows access when INNER_API is enabled""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_inner_api_disabled(self, app: Flask): + """Test that 404 is returned when INNER_API is disabled""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_401_when_api_key_missing(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is missing""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + def test_should_return_401_when_api_key_invalid(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is invalid""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + +class TestEnterpriseInnerApiOnly: + """Test enterprise_inner_api_only decorator""" + + def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask): + """Test that valid API key allows access when INNER_API is enabled""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_inner_api_disabled(self, app: Flask): + """Test that 404 is returned when INNER_API is disabled""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_401_when_api_key_missing(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is missing""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + def test_should_return_401_when_api_key_invalid(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is invalid""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + +class TestEnterpriseInnerApiUserAuth: + """Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication""" + + def test_should_pass_through_when_inner_api_disabled(self, app: Flask): + """Test that request passes through when INNER_API is disabled""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_authorization_header_missing(self, app: Flask): + """Test that request passes through when Authorization header is missing""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_authorization_format_invalid(self, app: Flask): + """Test that request passes through when Authorization format is invalid (no colon)""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(headers={"Authorization": "invalid_format"}): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask): + """Test that request passes through when HMAC signature is invalid""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act - use wrong signature + with app.test_request_context( + headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"} + ): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_inject_user_when_hmac_signature_valid(self, app: Flask): + """Test that user is injected when HMAC signature is valid""" + # Arrange + from base64 import b64encode + from hashlib import sha1 + from hmac import new as hmac_new + + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user") + + # Calculate valid HMAC signature + user_id = "user123" + inner_api_key = "valid_key" + data_to_sign = f"DIFY {user_id}" + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + valid_signature = b64encode(signature.digest()).decode("utf-8") + + # Create mock user + mock_user = MagicMock() + mock_user.id = user_id + + # Act + with app.test_request_context( + headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key} + ): + with patch.object(dify_config, "INNER_API", True): + with patch("controllers.inner_api.wraps.db.session.query") as mock_query: + mock_query.return_value.where.return_value.first.return_value = mock_user + result = protected_view() + + # Assert + assert result == mock_user + + +class TestPluginInnerApiOnly: + """Test plugin_inner_api_only decorator""" + + def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask): + """Test that valid API key allows access when PLUGIN_DAEMON_KEY is set""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"): + with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask): + """Test that 404 is returned when PLUGIN_DAEMON_KEY is not set""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_404_when_api_key_invalid(self, app: Flask): + """Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"): + with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 diff --git a/api/tests/unit_tests/controllers/inner_api/test_mail.py b/api/tests/unit_tests/controllers/inner_api/test_mail.py new file mode 100644 index 0000000000..c2ca35693e --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/test_mail.py @@ -0,0 +1,206 @@ +""" +Unit tests for inner_api mail module +""" + +from unittest.mock import patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.mail import ( + BaseMail, + BillingMail, + EnterpriseMail, + InnerMailPayload, +) + + +class TestInnerMailPayload: + """Test InnerMailPayload Pydantic model""" + + def test_valid_payload_with_all_fields(self): + """Test valid payload with all fields passes validation""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + "substitutions": {"key": "value"}, + } + payload = InnerMailPayload.model_validate(data) + assert payload.to == ["test@example.com"] + assert payload.subject == "Test Subject" + assert payload.body == "Test Body" + assert payload.substitutions == {"key": "value"} + + def test_valid_payload_without_substitutions(self): + """Test valid payload without optional substitutions""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + payload = InnerMailPayload.model_validate(data) + assert payload.to == ["test@example.com"] + assert payload.subject == "Test Subject" + assert payload.body == "Test Body" + assert payload.substitutions is None + + def test_empty_to_list_fails_validation(self): + """Test that empty 'to' list fails validation due to min_length=1""" + data = { + "to": [], + "subject": "Test Subject", + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_multiple_recipients_allowed(self): + """Test that multiple recipients are allowed""" + data = { + "to": ["user1@example.com", "user2@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + payload = InnerMailPayload.model_validate(data) + assert len(payload.to) == 2 + assert "user1@example.com" in payload.to + assert "user2@example.com" in payload.to + + def test_missing_to_field_fails_validation(self): + """Test that missing 'to' field fails validation""" + data = { + "subject": "Test Subject", + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_missing_subject_fails_validation(self): + """Test that missing 'subject' field fails validation""" + data = { + "to": ["test@example.com"], + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_missing_body_fails_validation(self): + """Test that missing 'body' field fails validation""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + +class TestBaseMail: + """Test BaseMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create BaseMail API instance""" + return BaseMail() + + @patch("controllers.inner_api.mail.send_inner_email_task") + def test_post_sends_email_task(self, mock_task, api_instance, app: Flask): + """Test that POST sends inner email task""" + # Arrange + mock_task.delay.return_value = None + + # Act + with app.test_request_context( + json={ + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + ): + with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns: + mock_ns.payload = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + result = api_instance.post() + + # Assert + assert result == ({"message": "success"}, 200) + mock_task.delay.assert_called_once_with( + to=["test@example.com"], + subject="Test Subject", + body="Test Body", + substitutions=None, + ) + + @patch("controllers.inner_api.mail.send_inner_email_task") + def test_post_with_substitutions(self, mock_task, api_instance, app: Flask): + """Test that POST sends email with substitutions""" + # Arrange + mock_task.delay.return_value = None + + # Act + with app.test_request_context(): + with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns: + mock_ns.payload = { + "to": ["test@example.com"], + "subject": "Hello {{name}}", + "body": "Welcome {{name}}!", + "substitutions": {"name": "John"}, + } + result = api_instance.post() + + # Assert + assert result == ({"message": "success"}, 200) + mock_task.delay.assert_called_once_with( + to=["test@example.com"], + subject="Hello {{name}}", + body="Welcome {{name}}!", + substitutions={"name": "John"}, + ) + + +class TestEnterpriseMail: + """Test EnterpriseMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create EnterpriseMail API instance""" + return EnterpriseMail() + + def test_has_enterprise_inner_api_only_decorator(self, api_instance): + """Test that EnterpriseMail has enterprise_inner_api_only decorator""" + # Check method_decorators + from controllers.inner_api.wraps import enterprise_inner_api_only + + assert enterprise_inner_api_only in api_instance.method_decorators + + def test_has_setup_required_decorator(self, api_instance): + """Test that EnterpriseMail has setup_required decorator""" + # Check by decorator name instead of object reference + decorator_names = [d.__name__ for d in api_instance.method_decorators] + assert "setup_required" in decorator_names + + +class TestBillingMail: + """Test BillingMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create BillingMail API instance""" + return BillingMail() + + def test_has_billing_inner_api_only_decorator(self, api_instance): + """Test that BillingMail has billing_inner_api_only decorator""" + # Check method_decorators + from controllers.inner_api.wraps import billing_inner_api_only + + assert billing_inner_api_only in api_instance.method_decorators + + def test_has_setup_required_decorator(self, api_instance): + """Test that BillingMail has setup_required decorator""" + # Check by decorator name instead of object reference + decorator_names = [d.__name__ for d in api_instance.method_decorators] + assert "setup_required" in decorator_names diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/__init__.py b/api/tests/unit_tests/controllers/inner_api/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py new file mode 100644 index 0000000000..4fbf0f7125 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -0,0 +1,184 @@ +""" +Unit tests for inner_api workspace module + +Tests Pydantic model validation and endpoint handler logic. +Auth/setup decorators are tested separately in test_auth_wraps.py; +handler tests use inspect.unwrap() to bypass them and focus on business logic. +""" + +import inspect +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.workspace.workspace import ( + EnterpriseWorkspace, + EnterpriseWorkspaceNoOwnerEmail, + WorkspaceCreatePayload, + WorkspaceOwnerlessPayload, +) + + +class TestWorkspaceCreatePayload: + """Test WorkspaceCreatePayload Pydantic model validation""" + + def test_valid_payload(self): + """Test valid payload with all fields passes validation""" + data = { + "name": "My Workspace", + "owner_email": "owner@example.com", + } + payload = WorkspaceCreatePayload.model_validate(data) + assert payload.name == "My Workspace" + assert payload.owner_email == "owner@example.com" + + def test_missing_name_fails_validation(self): + """Test that missing name fails validation""" + data = {"owner_email": "owner@example.com"} + with pytest.raises(ValidationError) as exc_info: + WorkspaceCreatePayload.model_validate(data) + assert "name" in str(exc_info.value) + + def test_missing_owner_email_fails_validation(self): + """Test that missing owner_email fails validation""" + data = {"name": "My Workspace"} + with pytest.raises(ValidationError) as exc_info: + WorkspaceCreatePayload.model_validate(data) + assert "owner_email" in str(exc_info.value) + + +class TestWorkspaceOwnerlessPayload: + """Test WorkspaceOwnerlessPayload Pydantic model validation""" + + def test_valid_payload(self): + """Test valid payload with name passes validation""" + data = {"name": "My Workspace"} + payload = WorkspaceOwnerlessPayload.model_validate(data) + assert payload.name == "My Workspace" + + def test_missing_name_fails_validation(self): + """Test that missing name fails validation""" + data = {} + with pytest.raises(ValidationError) as exc_info: + WorkspaceOwnerlessPayload.model_validate(data) + assert "name" in str(exc_info.value) + + +class TestEnterpriseWorkspace: + """Test EnterpriseWorkspace API endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py) + and exercise the core business logic directly. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseWorkspace() + + def test_has_post_method(self, api_instance): + """Test that EnterpriseWorkspace has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.workspace.workspace.tenant_was_created") + @patch("controllers.inner_api.workspace.workspace.TenantService") + @patch("controllers.inner_api.workspace.workspace.db") + def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask): + """Test that post() creates a workspace and assigns the owner account""" + # Arrange + mock_account = MagicMock() + mock_account.email = "owner@example.com" + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + + now = datetime(2025, 1, 1, 12, 0, 0) + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_tenant.name = "My Workspace" + mock_tenant.plan = "sandbox" + mock_tenant.status = "normal" + mock_tenant.created_at = now + mock_tenant.updated_at = now + mock_tenant_svc.create_tenant.return_value = mock_tenant + + # Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py) + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"} + result = unwrapped_post(api_instance) + + # Assert + assert result["message"] == "enterprise workspace created." + assert result["tenant"]["id"] == "tenant-id" + assert result["tenant"]["name"] == "My Workspace" + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner") + mock_event.send.assert_called_once_with(mock_tenant) + + @patch("controllers.inner_api.workspace.workspace.db") + def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask): + """Test that post() returns 404 when the owner account does not exist""" + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"} + result = unwrapped_post(api_instance) + + # Assert + assert result == ({"message": "owner account not found."}, 404) + + +class TestEnterpriseWorkspaceNoOwnerEmail: + """Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py) + and exercise the core business logic directly. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseWorkspaceNoOwnerEmail() + + def test_has_post_method(self, api_instance): + """Test that endpoint has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.workspace.workspace.tenant_was_created") + @patch("controllers.inner_api.workspace.workspace.TenantService") + def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask): + """Test that post() creates a workspace without an owner and returns expected fields""" + # Arrange + now = datetime(2025, 1, 1, 12, 0, 0) + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_tenant.name = "My Workspace" + mock_tenant.encrypt_public_key = "pub-key" + mock_tenant.plan = "sandbox" + mock_tenant.status = "normal" + mock_tenant.custom_config = None + mock_tenant.created_at = now + mock_tenant.updated_at = now + mock_tenant_svc.create_tenant.return_value = mock_tenant + + # Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py) + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace"} + result = unwrapped_post(api_instance) + + # Assert + assert result["message"] == "enterprise workspace created." + assert result["tenant"]["id"] == "tenant-id" + assert result["tenant"]["encrypt_public_key"] == "pub-key" + assert result["tenant"]["custom_config"] == {} + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_event.send.assert_called_once_with(mock_tenant)