mirror of
https://github.com/langgenius/dify.git
synced 2026-06-16 22:11:09 +08:00
feat(openapi): add OpenApiErrorFormatter normalizing all error paths to ErrorBody
This commit is contained in:
parent
24b6e6f983
commit
41f827b609
@ -9,8 +9,12 @@ pre-existing ``e.data`` override the registered handler's return value.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from libs.external_api import http_status_message
|
||||
|
||||
|
||||
class OpenApiErrorCode(StrEnum):
|
||||
@ -64,3 +68,129 @@ class ErrorBody(BaseModel):
|
||||
status: int
|
||||
hint: str | None = None
|
||||
details: list[ErrorDetail] | None = None
|
||||
|
||||
|
||||
_CODE_BY_STATUS: dict[int, OpenApiErrorCode] = {
|
||||
400: OpenApiErrorCode.BAD_REQUEST,
|
||||
401: OpenApiErrorCode.UNAUTHORIZED,
|
||||
403: OpenApiErrorCode.FORBIDDEN,
|
||||
404: OpenApiErrorCode.NOT_FOUND,
|
||||
405: OpenApiErrorCode.METHOD_NOT_ALLOWED,
|
||||
406: OpenApiErrorCode.NOT_ACCEPTABLE,
|
||||
409: OpenApiErrorCode.CONFLICT,
|
||||
413: OpenApiErrorCode.REQUEST_TOO_LARGE,
|
||||
415: OpenApiErrorCode.UNSUPPORTED_MEDIA_TYPE,
|
||||
422: OpenApiErrorCode.INVALID_PARAM,
|
||||
429: OpenApiErrorCode.TOO_MANY_REQUESTS,
|
||||
500: OpenApiErrorCode.INTERNAL_ERROR,
|
||||
502: OpenApiErrorCode.BAD_GATEWAY,
|
||||
}
|
||||
|
||||
_GENERIC_500_MESSAGE = "Internal Server Error"
|
||||
|
||||
|
||||
class OpenApiError(HTTPException):
|
||||
"""Dedicated throwable for the /openapi/v1 surface.
|
||||
|
||||
A subclass declares ``code`` (HTTP status), ``error_code`` and
|
||||
``description`` exactly once; call sites just ``raise SomeError()`` —
|
||||
no per-site dict building, no duplicated message constants. The
|
||||
formatter emits all three (plus optional ``hint``/``details``) verbatim.
|
||||
"""
|
||||
|
||||
code = 400
|
||||
error_code: OpenApiErrorCode = OpenApiErrorCode.UNKNOWN
|
||||
hint: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str | None = None,
|
||||
*,
|
||||
hint: str | None = None,
|
||||
details: list[ErrorDetail] | None = None,
|
||||
) -> None:
|
||||
super().__init__(description=message)
|
||||
if hint is not None:
|
||||
self.hint = hint
|
||||
self.details = details
|
||||
|
||||
|
||||
class OpenApiErrorFormatter:
|
||||
"""Builds the canonical ErrorBody from whatever the shared handlers computed.
|
||||
|
||||
Resolution order for ``code``: explicit ``error_code`` class attribute
|
||||
(BaseHTTPException subclasses and OpenApiError subclasses) → HTTP status
|
||||
map → ``unknown``. Class-name-derived codes from the shared handler are
|
||||
deliberately ignored — they are not a stable contract.
|
||||
"""
|
||||
|
||||
def finalize(self, e: Exception, data: dict[str, Any], status_code: int) -> dict[str, Any]:
|
||||
exc_data = getattr(e, "data", None)
|
||||
merged: dict[str, Any] = {**data, **exc_data} if isinstance(exc_data, dict) else dict(data)
|
||||
|
||||
body = ErrorBody(
|
||||
code=self._resolve_code(e, status_code),
|
||||
message=self._resolve_message(merged, status_code),
|
||||
status=status_code,
|
||||
hint=self._resolve_hint(e),
|
||||
details=self._extract_details(e, merged),
|
||||
)
|
||||
wire = body.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
# flask-restx Api.handle_error does `data = getattr(e, "data", default_data)`
|
||||
# AFTER our handler returns, so a pre-existing e.data (flask_restx.abort,
|
||||
# BaseHTTPException) would override the canonical body. Rewrite it.
|
||||
try:
|
||||
e.data = wire # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
pass
|
||||
return wire
|
||||
|
||||
def _resolve_code(self, e: Exception, status_code: int) -> str:
|
||||
explicit = getattr(type(e), "error_code", None)
|
||||
if isinstance(explicit, (OpenApiErrorCode, str)) and str(explicit) != "unknown":
|
||||
return str(explicit)
|
||||
return str(_CODE_BY_STATUS.get(status_code, OpenApiErrorCode.UNKNOWN))
|
||||
|
||||
def _resolve_message(self, merged: dict[str, Any], status_code: int) -> str:
|
||||
if status_code >= 500:
|
||||
return _GENERIC_500_MESSAGE
|
||||
message = merged.get("message")
|
||||
if isinstance(message, str) and message:
|
||||
return message
|
||||
return http_status_message(status_code) or "request failed"
|
||||
|
||||
def _resolve_hint(self, e: Exception) -> str | None:
|
||||
hint = getattr(e, "hint", None)
|
||||
return hint if isinstance(hint, str) and hint else None
|
||||
|
||||
def _extract_details(self, e: Exception, merged: dict[str, Any]) -> list[ErrorDetail] | None:
|
||||
explicit = getattr(e, "details", None)
|
||||
if isinstance(explicit, list) and explicit and all(isinstance(d, ErrorDetail) for d in explicit):
|
||||
return explicit
|
||||
# an already-canonical body (e.g. e.data rewritten by a prior finalize)
|
||||
# carries "details"; re-validate so finalize stays idempotent
|
||||
canonical = merged.get("details")
|
||||
if isinstance(canonical, list) and canonical and all(isinstance(d, dict) for d in canonical):
|
||||
return [ErrorDetail.model_validate(d) for d in canonical]
|
||||
errors = merged.get("errors")
|
||||
if isinstance(errors, list) and errors:
|
||||
details = [
|
||||
ErrorDetail(
|
||||
type=str(item.get("type", "invalid")),
|
||||
loc=[part for part in item.get("loc", []) if self._is_loc_part(part)],
|
||||
msg=str(item.get("msg", "")),
|
||||
)
|
||||
for item in errors
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
return details or None
|
||||
params = merged.get("params")
|
||||
if isinstance(params, str) and params:
|
||||
return [ErrorDetail(type="invalid", loc=[params], msg=str(merged.get("message", "")))]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_loc_part(part: Any) -> bool:
|
||||
# bool is an int subclass but is not a valid path segment
|
||||
return isinstance(part, (str, int)) and not isinstance(part, bool)
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
"""Wire-contract tests for the canonical /openapi/v1 error body."""
|
||||
|
||||
from controllers.openapi._errors import ErrorBody, ErrorDetail, OpenApiErrorCode
|
||||
import pytest
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi._errors import ErrorBody, ErrorDetail, OpenApiError, OpenApiErrorCode, OpenApiErrorFormatter
|
||||
from controllers.web.error import ProviderQuotaExceededError
|
||||
|
||||
|
||||
class TestErrorBodyModel:
|
||||
@ -31,3 +35,141 @@ class TestErrorBodyModel:
|
||||
body = ErrorBody.model_validate({"code": "some_future_code", "message": "x", "status": 400})
|
||||
|
||||
assert body.code == "some_future_code"
|
||||
|
||||
|
||||
class TestOpenApiErrorFormatter:
|
||||
@pytest.fixture
|
||||
def fmt(self):
|
||||
return OpenApiErrorFormatter()
|
||||
|
||||
def test_plain_werkzeug_exception_maps_code_from_status(self, fmt):
|
||||
e = NotFound("app not found")
|
||||
data = {"code": "not_found", "message": "app not found", "status": 404}
|
||||
|
||||
wire = fmt.finalize(e, data, 404)
|
||||
|
||||
assert wire == {"code": "not_found", "message": "app not found", "status": 404}
|
||||
|
||||
def test_422_maps_to_invalid_param(self, fmt):
|
||||
e = UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
data = {"code": "unprocessable_entity", "message": e.description, "status": 422}
|
||||
|
||||
wire = fmt.finalize(e, data, 422)
|
||||
|
||||
assert wire["code"] == "invalid_param"
|
||||
|
||||
def test_flask_restx_abort_data_path_yields_canonical_body(self, fmt):
|
||||
# Simulates _contract.py's abort(422, message=..., errors=...): flask_restx
|
||||
# attaches kwargs to e.data, which handle_error would otherwise put on the
|
||||
# wire verbatim (no code/status).
|
||||
e = UnprocessableEntity()
|
||||
e.data = {
|
||||
"message": "Request validation failed",
|
||||
"errors": [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1", "extra": "drop me"}],
|
||||
}
|
||||
data = {"code": "unprocessable_entity", "message": e.description, "status": 422}
|
||||
|
||||
wire = fmt.finalize(e, data, 422)
|
||||
|
||||
assert wire["code"] == "invalid_param"
|
||||
assert wire["message"] == "Request validation failed"
|
||||
assert wire["status"] == 422
|
||||
assert wire["details"] == [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1"}]
|
||||
# the override channel now carries the canonical body
|
||||
assert e.data == wire
|
||||
|
||||
def test_finalize_is_idempotent(self, fmt):
|
||||
e = UnprocessableEntity()
|
||||
e.data = {
|
||||
"message": "Request validation failed",
|
||||
"errors": [{"type": "int_parsing", "loc": ["page"], "msg": "must be >= 1"}],
|
||||
}
|
||||
data = {"code": "unprocessable_entity", "message": e.description, "status": 422}
|
||||
|
||||
first = fmt.finalize(e, data, 422)
|
||||
second = fmt.finalize(e, data, 422)
|
||||
|
||||
assert second == first
|
||||
|
||||
def test_base_http_exception_error_code_wins_over_status_map(self, fmt):
|
||||
e = ProviderQuotaExceededError()
|
||||
data = dict(e.data)
|
||||
|
||||
wire = fmt.finalize(e, data, 400)
|
||||
|
||||
assert wire["code"] == "provider_quota_exceeded"
|
||||
assert wire["status"] == 400
|
||||
|
||||
def test_hint_attribute_is_emitted(self, fmt):
|
||||
e = Conflict("seat limit")
|
||||
e.hint = "remove a member first"
|
||||
data = {"code": "conflict", "message": "seat limit", "status": 409}
|
||||
|
||||
wire = fmt.finalize(e, data, 409)
|
||||
|
||||
assert wire["hint"] == "remove a member first"
|
||||
|
||||
def test_params_shape_becomes_details(self, fmt):
|
||||
e = ValueError("is required")
|
||||
data = {"code": "invalid_param", "message": "is required", "params": "email", "status": 400}
|
||||
|
||||
wire = fmt.finalize(e, data, 400)
|
||||
|
||||
assert "params" not in wire
|
||||
assert wire["details"] == [{"type": "invalid", "loc": ["email"], "msg": "is required"}]
|
||||
|
||||
def test_catch_all_exception_never_leaks_str_e(self, fmt):
|
||||
e = RuntimeError("postgres password=hunter2 connection refused")
|
||||
data = {"message": str(e), "code": "unknown", "status": 500}
|
||||
|
||||
wire = fmt.finalize(e, data, 500)
|
||||
|
||||
assert wire["code"] == "internal_server_error"
|
||||
assert "hunter2" not in wire["message"]
|
||||
|
||||
def test_unmapped_status_falls_back_to_unknown(self, fmt):
|
||||
from werkzeug.exceptions import Gone
|
||||
|
||||
e = Gone()
|
||||
data = {"code": "gone", "message": e.description, "status": 410}
|
||||
|
||||
wire = fmt.finalize(e, data, 410)
|
||||
|
||||
assert wire["code"] == "unknown"
|
||||
|
||||
def test_openapi_error_subclass_is_throw_and_done(self, fmt):
|
||||
# The dedicated throwable: subclass declares status + code + message once,
|
||||
# call sites just `raise`; the formatter emits everything verbatim.
|
||||
class TeapotError(OpenApiError):
|
||||
code = 418
|
||||
error_code = OpenApiErrorCode.INVALID_PARAM
|
||||
description = "kettle says no"
|
||||
|
||||
e = TeapotError(details=[ErrorDetail(type="invalid", loc=["kettle"], msg="too hot")])
|
||||
data = {"code": "im_a_teapot", "message": e.description, "status": 418}
|
||||
|
||||
wire = fmt.finalize(e, data, 418)
|
||||
|
||||
assert wire["code"] == OpenApiErrorCode.INVALID_PARAM
|
||||
assert wire["message"] == TeapotError.description
|
||||
assert wire["details"] == [{"type": "invalid", "loc": ["kettle"], "msg": "too hot"}]
|
||||
|
||||
def test_openapi_error_message_override(self, fmt):
|
||||
e = OpenApiError("custom reason")
|
||||
data = {"code": "bad_request", "message": e.description, "status": 400}
|
||||
|
||||
wire = fmt.finalize(e, data, 400)
|
||||
|
||||
assert wire["message"] == "custom reason"
|
||||
assert wire["code"] == "bad_request"
|
||||
|
||||
def test_every_emitted_code_is_an_enum_member(self, fmt):
|
||||
# Guard against the formatter inventing codes outside the contract.
|
||||
cases = [
|
||||
(NotFound("x"), {"code": "not_found", "message": "x", "status": 404}, 404),
|
||||
(ProviderQuotaExceededError(), dict(ProviderQuotaExceededError().data), 400),
|
||||
(ValueError("x"), {"code": "invalid_param", "message": "x", "status": 400}, 400),
|
||||
]
|
||||
for e, data, status in cases:
|
||||
wire = fmt.finalize(e, data, status)
|
||||
assert wire["code"] in {c.value for c in OpenApiErrorCode}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user