From 41f827b609a90c4645b89e66158d5895156cd2a8 Mon Sep 17 00:00:00 2001 From: GareArc Date: Wed, 10 Jun 2026 02:15:44 -0700 Subject: [PATCH] feat(openapi): add OpenApiErrorFormatter normalizing all error paths to ErrorBody --- api/controllers/openapi/_errors.py | 130 ++++++++++++++++ .../openapi/test_error_contract.py | 144 +++++++++++++++++- 2 files changed, 273 insertions(+), 1 deletion(-) diff --git a/api/controllers/openapi/_errors.py b/api/controllers/openapi/_errors.py index 6aac3044ad..d48cda4c83 100644 --- a/api/controllers/openapi/_errors.py +++ b/api/controllers/openapi/_errors.py @@ -9,8 +9,12 @@ pre-existing ``e.data`` override the registered handler's return value. """ from enum import StrEnum +from typing import Any from pydantic import BaseModel +from werkzeug.exceptions import HTTPException + +from libs.external_api import http_status_message class OpenApiErrorCode(StrEnum): @@ -64,3 +68,129 @@ class ErrorBody(BaseModel): status: int hint: str | None = None details: list[ErrorDetail] | None = None + + +_CODE_BY_STATUS: dict[int, OpenApiErrorCode] = { + 400: OpenApiErrorCode.BAD_REQUEST, + 401: OpenApiErrorCode.UNAUTHORIZED, + 403: OpenApiErrorCode.FORBIDDEN, + 404: OpenApiErrorCode.NOT_FOUND, + 405: OpenApiErrorCode.METHOD_NOT_ALLOWED, + 406: OpenApiErrorCode.NOT_ACCEPTABLE, + 409: OpenApiErrorCode.CONFLICT, + 413: OpenApiErrorCode.REQUEST_TOO_LARGE, + 415: OpenApiErrorCode.UNSUPPORTED_MEDIA_TYPE, + 422: OpenApiErrorCode.INVALID_PARAM, + 429: OpenApiErrorCode.TOO_MANY_REQUESTS, + 500: OpenApiErrorCode.INTERNAL_ERROR, + 502: OpenApiErrorCode.BAD_GATEWAY, +} + +_GENERIC_500_MESSAGE = "Internal Server Error" + + +class OpenApiError(HTTPException): + """Dedicated throwable for the /openapi/v1 surface. + + A subclass declares ``code`` (HTTP status), ``error_code`` and + ``description`` exactly once; call sites just ``raise SomeError()`` — + no per-site dict building, no duplicated message constants. The + formatter emits all three (plus optional ``hint``/``details``) verbatim. + """ + + code = 400 + error_code: OpenApiErrorCode = OpenApiErrorCode.UNKNOWN + hint: str | None = None + + def __init__( + self, + message: str | None = None, + *, + hint: str | None = None, + details: list[ErrorDetail] | None = None, + ) -> None: + super().__init__(description=message) + if hint is not None: + self.hint = hint + self.details = details + + +class OpenApiErrorFormatter: + """Builds the canonical ErrorBody from whatever the shared handlers computed. + + Resolution order for ``code``: explicit ``error_code`` class attribute + (BaseHTTPException subclasses and OpenApiError subclasses) → HTTP status + map → ``unknown``. Class-name-derived codes from the shared handler are + deliberately ignored — they are not a stable contract. + """ + + def finalize(self, e: Exception, data: dict[str, Any], status_code: int) -> dict[str, Any]: + exc_data = getattr(e, "data", None) + merged: dict[str, Any] = {**data, **exc_data} if isinstance(exc_data, dict) else dict(data) + + body = ErrorBody( + code=self._resolve_code(e, status_code), + message=self._resolve_message(merged, status_code), + status=status_code, + hint=self._resolve_hint(e), + details=self._extract_details(e, merged), + ) + wire = body.model_dump(mode="json", exclude_none=True) + + # flask-restx Api.handle_error does `data = getattr(e, "data", default_data)` + # AFTER our handler returns, so a pre-existing e.data (flask_restx.abort, + # BaseHTTPException) would override the canonical body. Rewrite it. + try: + e.data = wire # type: ignore[attr-defined] + except AttributeError: + pass + return wire + + def _resolve_code(self, e: Exception, status_code: int) -> str: + explicit = getattr(type(e), "error_code", None) + if isinstance(explicit, (OpenApiErrorCode, str)) and str(explicit) != "unknown": + return str(explicit) + return str(_CODE_BY_STATUS.get(status_code, OpenApiErrorCode.UNKNOWN)) + + def _resolve_message(self, merged: dict[str, Any], status_code: int) -> str: + if status_code >= 500: + return _GENERIC_500_MESSAGE + message = merged.get("message") + if isinstance(message, str) and message: + return message + return http_status_message(status_code) or "request failed" + + def _resolve_hint(self, e: Exception) -> str | None: + hint = getattr(e, "hint", None) + return hint if isinstance(hint, str) and hint else None + + def _extract_details(self, e: Exception, merged: dict[str, Any]) -> list[ErrorDetail] | None: + explicit = getattr(e, "details", None) + if isinstance(explicit, list) and explicit and all(isinstance(d, ErrorDetail) for d in explicit): + return explicit + # an already-canonical body (e.g. e.data rewritten by a prior finalize) + # carries "details"; re-validate so finalize stays idempotent + canonical = merged.get("details") + if isinstance(canonical, list) and canonical and all(isinstance(d, dict) for d in canonical): + return [ErrorDetail.model_validate(d) for d in canonical] + errors = merged.get("errors") + if isinstance(errors, list) and errors: + details = [ + ErrorDetail( + type=str(item.get("type", "invalid")), + loc=[part for part in item.get("loc", []) if self._is_loc_part(part)], + msg=str(item.get("msg", "")), + ) + for item in errors + if isinstance(item, dict) + ] + return details or None + params = merged.get("params") + if isinstance(params, str) and params: + return [ErrorDetail(type="invalid", loc=[params], msg=str(merged.get("message", "")))] + return None + + @staticmethod + def _is_loc_part(part: Any) -> bool: + # bool is an int subclass but is not a valid path segment + return isinstance(part, (str, int)) and not isinstance(part, bool) diff --git a/api/tests/unit_tests/controllers/openapi/test_error_contract.py b/api/tests/unit_tests/controllers/openapi/test_error_contract.py index 0709d82cae..d559385d10 100644 --- a/api/tests/unit_tests/controllers/openapi/test_error_contract.py +++ b/api/tests/unit_tests/controllers/openapi/test_error_contract.py @@ -1,6 +1,10 @@ """Wire-contract tests for the canonical /openapi/v1 error body.""" -from controllers.openapi._errors import ErrorBody, ErrorDetail, OpenApiErrorCode +import pytest +from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity + +from controllers.openapi._errors import ErrorBody, ErrorDetail, OpenApiError, OpenApiErrorCode, OpenApiErrorFormatter +from controllers.web.error import ProviderQuotaExceededError class TestErrorBodyModel: @@ -31,3 +35,141 @@ class TestErrorBodyModel: body = ErrorBody.model_validate({"code": "some_future_code", "message": "x", "status": 400}) assert body.code == "some_future_code" + + +class TestOpenApiErrorFormatter: + @pytest.fixture + def fmt(self): + return OpenApiErrorFormatter() + + def test_plain_werkzeug_exception_maps_code_from_status(self, fmt): + e = NotFound("app not found") + data = {"code": "not_found", "message": "app not found", "status": 404} + + wire = fmt.finalize(e, data, 404) + + assert wire == {"code": "not_found", "message": "app not found", "status": 404} + + def test_422_maps_to_invalid_param(self, fmt): + e = UnprocessableEntity("workspace_id is required for name-based lookup") + data = {"code": "unprocessable_entity", "message": e.description, "status": 422} + + wire = fmt.finalize(e, data, 422) + + assert wire["code"] == "invalid_param" + + def test_flask_restx_abort_data_path_yields_canonical_body(self, fmt): + # Simulates _contract.py's abort(422, message=..., errors=...): flask_restx + # attaches kwargs to e.data, which handle_error would otherwise put on the + # wire verbatim (no code/status). + e = UnprocessableEntity() + e.data = { + "message": "Request validation failed", + "errors": [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1", "extra": "drop me"}], + } + data = {"code": "unprocessable_entity", "message": e.description, "status": 422} + + wire = fmt.finalize(e, data, 422) + + assert wire["code"] == "invalid_param" + assert wire["message"] == "Request validation failed" + assert wire["status"] == 422 + assert wire["details"] == [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1"}] + # the override channel now carries the canonical body + assert e.data == wire + + def test_finalize_is_idempotent(self, fmt): + e = UnprocessableEntity() + e.data = { + "message": "Request validation failed", + "errors": [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1"}], + } + data = {"code": "unprocessable_entity", "message": e.description, "status": 422} + + first = fmt.finalize(e, data, 422) + second = fmt.finalize(e, data, 422) + + assert second == first + + def test_base_http_exception_error_code_wins_over_status_map(self, fmt): + e = ProviderQuotaExceededError() + data = dict(e.data) + + wire = fmt.finalize(e, data, 400) + + assert wire["code"] == "provider_quota_exceeded" + assert wire["status"] == 400 + + def test_hint_attribute_is_emitted(self, fmt): + e = Conflict("seat limit") + e.hint = "remove a member first" + data = {"code": "conflict", "message": "seat limit", "status": 409} + + wire = fmt.finalize(e, data, 409) + + assert wire["hint"] == "remove a member first" + + def test_params_shape_becomes_details(self, fmt): + e = ValueError("is required") + data = {"code": "invalid_param", "message": "is required", "params": "email", "status": 400} + + wire = fmt.finalize(e, data, 400) + + assert "params" not in wire + assert wire["details"] == [{"type": "invalid", "loc": ["email"], "msg": "is required"}] + + def test_catch_all_exception_never_leaks_str_e(self, fmt): + e = RuntimeError("postgres password=hunter2 connection refused") + data = {"message": str(e), "code": "unknown", "status": 500} + + wire = fmt.finalize(e, data, 500) + + assert wire["code"] == "internal_server_error" + assert "hunter2" not in wire["message"] + + def test_unmapped_status_falls_back_to_unknown(self, fmt): + from werkzeug.exceptions import Gone + + e = Gone() + data = {"code": "gone", "message": e.description, "status": 410} + + wire = fmt.finalize(e, data, 410) + + assert wire["code"] == "unknown" + + def test_openapi_error_subclass_is_throw_and_done(self, fmt): + # The dedicated throwable: subclass declares status + code + message once, + # call sites just `raise`; the formatter emits everything verbatim. + class TeapotError(OpenApiError): + code = 418 + error_code = OpenApiErrorCode.INVALID_PARAM + description = "kettle says no" + + e = TeapotError(details=[ErrorDetail(type="invalid", loc=["kettle"], msg="too hot")]) + data = {"code": "im_a_teapot", "message": e.description, "status": 418} + + wire = fmt.finalize(e, data, 418) + + assert wire["code"] == OpenApiErrorCode.INVALID_PARAM + assert wire["message"] == TeapotError.description + assert wire["details"] == [{"type": "invalid", "loc": ["kettle"], "msg": "too hot"}] + + def test_openapi_error_message_override(self, fmt): + e = OpenApiError("custom reason") + data = {"code": "bad_request", "message": e.description, "status": 400} + + wire = fmt.finalize(e, data, 400) + + assert wire["message"] == "custom reason" + assert wire["code"] == "bad_request" + + def test_every_emitted_code_is_an_enum_member(self, fmt): + # Guard against the formatter inventing codes outside the contract. + cases = [ + (NotFound("x"), {"code": "not_found", "message": "x", "status": 404}, 404), + (ProviderQuotaExceededError(), dict(ProviderQuotaExceededError().data), 400), + (ValueError("x"), {"code": "invalid_param", "message": "x", "status": 400}, 400), + ] + for e, data, status in cases: + wire = fmt.finalize(e, data, status) + assert wire["code"] in {c.value for c in OpenApiErrorCode}