diff --git a/api/controllers/openapi/_errors.py b/api/controllers/openapi/_errors.py index d48cda4c83..dcadfce1e8 100644 --- a/api/controllers/openapi/_errors.py +++ b/api/controllers/openapi/_errors.py @@ -194,3 +194,17 @@ class OpenApiErrorFormatter: def _is_loc_part(part: Any) -> bool: # bool is an int subclass but is not a valid path segment return isinstance(part, (str, int)) and not isinstance(part, bool) + + +class MemberLimitExceeded(OpenApiError): # noqa: N818 + code = 403 + error_code = OpenApiErrorCode.MEMBER_LIMIT_EXCEEDED + description = "Subscription member limit reached." + hint = "Upgrade your plan to invite more members or remove an existing member first." + + +class MemberLicenseExceeded(OpenApiError): # noqa: N818 + code = 403 + error_code = OpenApiErrorCode.MEMBER_LICENSE_EXCEEDED + description = "Workspace member license capacity reached." + hint = "Contact your workspace administrator to expand the license seat count." diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 902337703a..0ff225271d 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -14,13 +14,13 @@ from __future__ import annotations from itertools import starmap from urllib import parse -from flask import jsonify, make_response from flask_restx import Resource -from werkzeug.exceptions import BadRequest, Forbidden, NotFound +from werkzeug.exceptions import BadRequest, NotFound from configs import dify_config from controllers.openapi import openapi_ns from controllers.openapi._contract import accepts, returns +from controllers.openapi._errors import MemberLicenseExceeded, MemberLimitExceeded from controllers.openapi._models import ( MemberActionResponse, MemberInvitePayload, @@ -77,34 +77,16 @@ def _load_account(account_id: object) -> Account: return account -def _quota_error(*, code: str, message: str, hint: str) -> Forbidden: - err = Forbidden(message) - err.response = make_response( - jsonify({"code": code, "message": message, "hint": hint}), - 403, - ) - return err - - def _check_member_invite_quota(tenant_id: str) -> None: features = FeatureService.get_features(tenant_id) if features.billing.enabled: members = features.members if 0 < members.limit <= members.size: - raise _quota_error( - code="members.limit_exceeded", - message="Subscription member limit reached.", - hint="Upgrade your plan to invite more members or remove an existing member first.", - ) + raise MemberLimitExceeded() - if features.workspace_members.enabled: - if not features.workspace_members.is_available(1): - raise _quota_error( - code="workspace_members.license_exceeded", - message="Workspace member license capacity reached.", - hint="Contact your workspace administrator to expand the license seat count.", - ) + if features.workspace_members.enabled and not features.workspace_members.is_available(1): + raise MemberLicenseExceeded() @openapi_ns.route("/workspaces") diff --git a/api/tests/unit_tests/controllers/openapi/test_error_contract.py b/api/tests/unit_tests/controllers/openapi/test_error_contract.py index 3165ac405c..684a123012 100644 --- a/api/tests/unit_tests/controllers/openapi/test_error_contract.py +++ b/api/tests/unit_tests/controllers/openapi/test_error_contract.py @@ -5,7 +5,15 @@ from unittest.mock import MagicMock, patch import pytest from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity -from controllers.openapi._errors import ErrorBody, ErrorDetail, OpenApiError, OpenApiErrorCode, OpenApiErrorFormatter +from controllers.openapi._errors import ( + ErrorBody, + ErrorDetail, + MemberLicenseExceeded, + MemberLimitExceeded, + OpenApiError, + OpenApiErrorCode, + OpenApiErrorFormatter, +) from controllers.web.error import ProviderQuotaExceededError @@ -177,6 +185,25 @@ class TestOpenApiErrorFormatter: assert wire["code"] in {c.value for c in OpenApiErrorCode} +class TestQuotaExceptions: + @pytest.fixture + def fmt(self): + return OpenApiErrorFormatter() + + @pytest.mark.parametrize("exc_class", [MemberLimitExceeded, MemberLicenseExceeded]) + def test_quota_exception_carries_declared_code_and_message(self, fmt, exc_class): + # Single source: assertions read the class attributes, no re-typed strings. + e = exc_class() + data = {"code": "forbidden", "message": e.description, "status": 403} + + wire = fmt.finalize(e, data, 403) + + assert wire["code"] == exc_class.error_code + assert wire["message"] == exc_class.description + assert wire["hint"] == exc_class.hint + assert wire["status"] == 403 + + class TestWireContract: """End-to-end: request in, canonical JSON out, through the real openapi blueprint.""" diff --git a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py index 6bb13ad322..4c09491ab5 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py +++ b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py @@ -29,9 +29,10 @@ import pytest from flask import Flask from flask.views import MethodView from pydantic import ValidationError -from werkzeug.exceptions import BadRequest, Forbidden, NotFound, UnprocessableEntity +from werkzeug.exceptions import BadRequest, NotFound, UnprocessableEntity from controllers.openapi import bp as openapi_bp +from controllers.openapi._errors import MemberLicenseExceeded, MemberLimitExceeded from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload from controllers.openapi.workspaces import ( WorkspaceMemberApi, @@ -507,11 +508,7 @@ def _invite_request(app, ws_id: str, acct_id: uuid.UUID): def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch): - """SaaS billing plan member cap → 403 with `members.limit_exceeded`. - - Verifies the envelope shape the CLI error-mapper relies on (code + - message + hint on the wire body). - """ + """SaaS billing plan member cap → MemberLimitExceeded (403).""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMembersApi() @@ -538,18 +535,14 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch): with _invite_request(app, ws_id, acct_id): _seed(_auth_ctx(account_id=acct_id)) - with pytest.raises(Forbidden) as exc_info: + with pytest.raises(MemberLimitExceeded): api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) - body = exc_info.value.response.json - assert body["code"] == "members.limit_exceeded" - assert "Subscription member limit" in body["message"] - assert body["hint"] invite_mock.assert_not_called() def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, monkeypatch): - """EE License workspace_members cap → 403 with `workspace_members.license_exceeded`. + """EE License workspace_members cap → MemberLicenseExceeded (403). Note: billing.enabled is False (EE without SaaS billing); only the license cap fires. @@ -584,13 +577,9 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo with _invite_request(app, ws_id, acct_id): _seed(_auth_ctx(account_id=acct_id)) - with pytest.raises(Forbidden) as exc_info: + with pytest.raises(MemberLicenseExceeded): api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) - body = exc_info.value.response.json - assert body["code"] == "workspace_members.license_exceeded" - assert "license" in body["message"].lower() - assert body["hint"] invite_mock.assert_not_called()