feat(openapi): add OpenApiErrorFormatter normalizing all error paths to ErrorBody

This commit is contained in:
GareArc 2026-06-10 02:15:44 -07:00
parent 24b6e6f983
commit 41f827b609
No known key found for this signature in database
2 changed files with 273 additions and 1 deletions

View File

@ -9,8 +9,12 @@ pre-existing ``e.data`` override the registered handler's return value.
"""
from enum import StrEnum
from typing import Any
from pydantic import BaseModel
from werkzeug.exceptions import HTTPException
from libs.external_api import http_status_message
class OpenApiErrorCode(StrEnum):
@ -64,3 +68,129 @@ class ErrorBody(BaseModel):
status: int
hint: str | None = None
details: list[ErrorDetail] | None = None
_CODE_BY_STATUS: dict[int, OpenApiErrorCode] = {
400: OpenApiErrorCode.BAD_REQUEST,
401: OpenApiErrorCode.UNAUTHORIZED,
403: OpenApiErrorCode.FORBIDDEN,
404: OpenApiErrorCode.NOT_FOUND,
405: OpenApiErrorCode.METHOD_NOT_ALLOWED,
406: OpenApiErrorCode.NOT_ACCEPTABLE,
409: OpenApiErrorCode.CONFLICT,
413: OpenApiErrorCode.REQUEST_TOO_LARGE,
415: OpenApiErrorCode.UNSUPPORTED_MEDIA_TYPE,
422: OpenApiErrorCode.INVALID_PARAM,
429: OpenApiErrorCode.TOO_MANY_REQUESTS,
500: OpenApiErrorCode.INTERNAL_ERROR,
502: OpenApiErrorCode.BAD_GATEWAY,
}
_GENERIC_500_MESSAGE = "Internal Server Error"
class OpenApiError(HTTPException):
"""Dedicated throwable for the /openapi/v1 surface.
A subclass declares ``code`` (HTTP status), ``error_code`` and
``description`` exactly once; call sites just ``raise SomeError()``
no per-site dict building, no duplicated message constants. The
formatter emits all three (plus optional ``hint``/``details``) verbatim.
"""
code = 400
error_code: OpenApiErrorCode = OpenApiErrorCode.UNKNOWN
hint: str | None = None
def __init__(
self,
message: str | None = None,
*,
hint: str | None = None,
details: list[ErrorDetail] | None = None,
) -> None:
super().__init__(description=message)
if hint is not None:
self.hint = hint
self.details = details
class OpenApiErrorFormatter:
"""Builds the canonical ErrorBody from whatever the shared handlers computed.
Resolution order for ``code``: explicit ``error_code`` class attribute
(BaseHTTPException subclasses and OpenApiError subclasses) HTTP status
map ``unknown``. Class-name-derived codes from the shared handler are
deliberately ignored they are not a stable contract.
"""
def finalize(self, e: Exception, data: dict[str, Any], status_code: int) -> dict[str, Any]:
exc_data = getattr(e, "data", None)
merged: dict[str, Any] = {**data, **exc_data} if isinstance(exc_data, dict) else dict(data)
body = ErrorBody(
code=self._resolve_code(e, status_code),
message=self._resolve_message(merged, status_code),
status=status_code,
hint=self._resolve_hint(e),
details=self._extract_details(e, merged),
)
wire = body.model_dump(mode="json", exclude_none=True)
# flask-restx Api.handle_error does `data = getattr(e, "data", default_data)`
# AFTER our handler returns, so a pre-existing e.data (flask_restx.abort,
# BaseHTTPException) would override the canonical body. Rewrite it.
try:
e.data = wire # type: ignore[attr-defined]
except AttributeError:
pass
return wire
def _resolve_code(self, e: Exception, status_code: int) -> str:
explicit = getattr(type(e), "error_code", None)
if isinstance(explicit, (OpenApiErrorCode, str)) and str(explicit) != "unknown":
return str(explicit)
return str(_CODE_BY_STATUS.get(status_code, OpenApiErrorCode.UNKNOWN))
def _resolve_message(self, merged: dict[str, Any], status_code: int) -> str:
if status_code >= 500:
return _GENERIC_500_MESSAGE
message = merged.get("message")
if isinstance(message, str) and message:
return message
return http_status_message(status_code) or "request failed"
def _resolve_hint(self, e: Exception) -> str | None:
hint = getattr(e, "hint", None)
return hint if isinstance(hint, str) and hint else None
def _extract_details(self, e: Exception, merged: dict[str, Any]) -> list[ErrorDetail] | None:
explicit = getattr(e, "details", None)
if isinstance(explicit, list) and explicit and all(isinstance(d, ErrorDetail) for d in explicit):
return explicit
# an already-canonical body (e.g. e.data rewritten by a prior finalize)
# carries "details"; re-validate so finalize stays idempotent
canonical = merged.get("details")
if isinstance(canonical, list) and canonical and all(isinstance(d, dict) for d in canonical):
return [ErrorDetail.model_validate(d) for d in canonical]
errors = merged.get("errors")
if isinstance(errors, list) and errors:
details = [
ErrorDetail(
type=str(item.get("type", "invalid")),
loc=[part for part in item.get("loc", []) if self._is_loc_part(part)],
msg=str(item.get("msg", "")),
)
for item in errors
if isinstance(item, dict)
]
return details or None
params = merged.get("params")
if isinstance(params, str) and params:
return [ErrorDetail(type="invalid", loc=[params], msg=str(merged.get("message", "")))]
return None
@staticmethod
def _is_loc_part(part: Any) -> bool:
# bool is an int subclass but is not a valid path segment
return isinstance(part, (str, int)) and not isinstance(part, bool)

View File

@ -1,6 +1,10 @@
"""Wire-contract tests for the canonical /openapi/v1 error body."""
from controllers.openapi._errors import ErrorBody, ErrorDetail, OpenApiErrorCode
import pytest
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
from controllers.openapi._errors import ErrorBody, ErrorDetail, OpenApiError, OpenApiErrorCode, OpenApiErrorFormatter
from controllers.web.error import ProviderQuotaExceededError
class TestErrorBodyModel:
@ -31,3 +35,141 @@ class TestErrorBodyModel:
body = ErrorBody.model_validate({"code": "some_future_code", "message": "x", "status": 400})
assert body.code == "some_future_code"
class TestOpenApiErrorFormatter:
@pytest.fixture
def fmt(self):
return OpenApiErrorFormatter()
def test_plain_werkzeug_exception_maps_code_from_status(self, fmt):
e = NotFound("app not found")
data = {"code": "not_found", "message": "app not found", "status": 404}
wire = fmt.finalize(e, data, 404)
assert wire == {"code": "not_found", "message": "app not found", "status": 404}
def test_422_maps_to_invalid_param(self, fmt):
e = UnprocessableEntity("workspace_id is required for name-based lookup")
data = {"code": "unprocessable_entity", "message": e.description, "status": 422}
wire = fmt.finalize(e, data, 422)
assert wire["code"] == "invalid_param"
def test_flask_restx_abort_data_path_yields_canonical_body(self, fmt):
# Simulates _contract.py's abort(422, message=..., errors=...): flask_restx
# attaches kwargs to e.data, which handle_error would otherwise put on the
# wire verbatim (no code/status).
e = UnprocessableEntity()
e.data = {
"message": "Request validation failed",
"errors": [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1", "extra": "drop me"}],
}
data = {"code": "unprocessable_entity", "message": e.description, "status": 422}
wire = fmt.finalize(e, data, 422)
assert wire["code"] == "invalid_param"
assert wire["message"] == "Request validation failed"
assert wire["status"] == 422
assert wire["details"] == [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1"}]
# the override channel now carries the canonical body
assert e.data == wire
def test_finalize_is_idempotent(self, fmt):
e = UnprocessableEntity()
e.data = {
"message": "Request validation failed",
"errors": [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1"}],
}
data = {"code": "unprocessable_entity", "message": e.description, "status": 422}
first = fmt.finalize(e, data, 422)
second = fmt.finalize(e, data, 422)
assert second == first
def test_base_http_exception_error_code_wins_over_status_map(self, fmt):
e = ProviderQuotaExceededError()
data = dict(e.data)
wire = fmt.finalize(e, data, 400)
assert wire["code"] == "provider_quota_exceeded"
assert wire["status"] == 400
def test_hint_attribute_is_emitted(self, fmt):
e = Conflict("seat limit")
e.hint = "remove a member first"
data = {"code": "conflict", "message": "seat limit", "status": 409}
wire = fmt.finalize(e, data, 409)
assert wire["hint"] == "remove a member first"
def test_params_shape_becomes_details(self, fmt):
e = ValueError("is required")
data = {"code": "invalid_param", "message": "is required", "params": "email", "status": 400}
wire = fmt.finalize(e, data, 400)
assert "params" not in wire
assert wire["details"] == [{"type": "invalid", "loc": ["email"], "msg": "is required"}]
def test_catch_all_exception_never_leaks_str_e(self, fmt):
e = RuntimeError("postgres password=hunter2 connection refused")
data = {"message": str(e), "code": "unknown", "status": 500}
wire = fmt.finalize(e, data, 500)
assert wire["code"] == "internal_server_error"
assert "hunter2" not in wire["message"]
def test_unmapped_status_falls_back_to_unknown(self, fmt):
from werkzeug.exceptions import Gone
e = Gone()
data = {"code": "gone", "message": e.description, "status": 410}
wire = fmt.finalize(e, data, 410)
assert wire["code"] == "unknown"
def test_openapi_error_subclass_is_throw_and_done(self, fmt):
# The dedicated throwable: subclass declares status + code + message once,
# call sites just `raise`; the formatter emits everything verbatim.
class TeapotError(OpenApiError):
code = 418
error_code = OpenApiErrorCode.INVALID_PARAM
description = "kettle says no"
e = TeapotError(details=[ErrorDetail(type="invalid", loc=["kettle"], msg="too hot")])
data = {"code": "im_a_teapot", "message": e.description, "status": 418}
wire = fmt.finalize(e, data, 418)
assert wire["code"] == OpenApiErrorCode.INVALID_PARAM
assert wire["message"] == TeapotError.description
assert wire["details"] == [{"type": "invalid", "loc": ["kettle"], "msg": "too hot"}]
def test_openapi_error_message_override(self, fmt):
e = OpenApiError("custom reason")
data = {"code": "bad_request", "message": e.description, "status": 400}
wire = fmt.finalize(e, data, 400)
assert wire["message"] == "custom reason"
assert wire["code"] == "bad_request"
def test_every_emitted_code_is_an_enum_member(self, fmt):
# Guard against the formatter inventing codes outside the contract.
cases = [
(NotFound("x"), {"code": "not_found", "message": "x", "status": 404}, 404),
(ProviderQuotaExceededError(), dict(ProviderQuotaExceededError().data), 400),
(ValueError("x"), {"code": "invalid_param", "message": "x", "status": 400}, 400),
]
for e, data, status in cases:
wire = fmt.finalize(e, data, status)
assert wire["code"] in {c.value for c in OpenApiErrorCode}