feat(api): Flask-RESTX response() vs actual return value checker (#36488)

This commit is contained in:
chariri 2026-05-22 00:05:06 +09:00 committed by GitHub
parent 092c8bca81
commit f19702f76c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 865 additions and 2 deletions

View File

@ -75,13 +75,19 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
@uv run --project api --dev ruff format ./api
@uv run --project api --dev ruff check --fix ./api
@$(MAKE) api-contract-lint
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
api-contract-lint:
@echo "🔎 Linting Flask response contracts..."
@uv run --project api --dev python api/dev/lint_response_contracts.py
@echo "✅ Response contract lint complete"
type-check:
@echo "📝 Running type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@ -191,6 +197,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
@echo " make type-check - Run type checks (pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@ -204,4 +211,4 @@ help:
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all

View File

@ -195,6 +195,7 @@ Before opening a PR / submitting:
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Document non-obvious behaviour with concise docstrings and comments.
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,

View File

@ -0,0 +1,664 @@
"""Lint Flask-RESTX response docs against statically visible response serializers.
This checker intentionally stays conservative. It only reports a hard schema
mismatch when both sides are statically known for the same 2xx status code:
a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)``
or ``Model.model_validate(...).model_dump()`` return.
Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing
response schemas, and returns with non-literal status codes are classified as
unknown so reviewers can triage them without blocking unrelated work. The one
intentional non-schema mismatch is a known body/schema on a no-body status such
as 204, 205, or 304.
"""
from __future__ import annotations
import argparse
import ast
import json
import sys
from collections import Counter, defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import asdict, dataclass, field
from http import HTTPStatus
from pathlib import Path
from typing import Any, Literal
HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"}
NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value}
DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web")
type Classification = Literal["valid", "mismatch", "unknown", "refactorable"]
type ActualKind = Literal[
"empty",
"model",
"model_dump_variable",
"none",
"raw_dict",
"raw_list",
"raw_value",
"unknown",
]
type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef
HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus}
HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus})
@dataclass(frozen=True)
class DocumentedResponse:
status: int
model: str | None
line: int
@dataclass(frozen=True)
class ActualResponse:
status: int | None
kind: ActualKind
model: str | None
line: int
@dataclass(frozen=True)
class ContractCheck:
classification: Classification
file: str
class_name: str
method: str
route: str
line: int
reason: str
documented: dict[int, str | None]
actual: list[ActualResponse]
@dataclass(frozen=True)
class ContractCheckContext:
"""Stable route-method context shared by every classification result."""
file: str
class_name: str
method: str
route: str
line: int
documented: dict[int, str | None]
def build(
self, classification: Classification, reason: str, actual_responses: Sequence[ActualResponse]
) -> ContractCheck:
return ContractCheck(
classification=classification,
file=self.file,
class_name=self.class_name,
method=self.method,
route=self.route,
line=self.line,
reason=reason,
documented=self.documented,
actual=list(actual_responses),
)
def mismatch(self, reason: str, documented: DocumentedResponse, actual: ActualResponse) -> ContractCheck:
return self.build("mismatch", f"{reason} (doc line {documented.line}, return line {actual.line})", [actual])
@dataclass
class VariableAssignmentSummary:
"""Track whether a local name is safe to treat as one specific response model."""
known_models: set[str] = field(default_factory=set)
has_unknown_assignment: bool = False
def add_known(self, model: str) -> None:
self.known_models.add(model)
def add_unknown(self) -> None:
self.has_unknown_assignment = True
def single_known_model(self) -> str | None:
if self.has_unknown_assignment or len(self.known_models) != 1:
return None
return next(iter(self.known_models))
def dotted_name(node: ast.AST) -> str | None:
match node:
case ast.Name():
return node.id
case ast.Attribute():
parent = dotted_name(node.value)
if parent:
return f"{parent}.{node.attr}"
return node.attr
return None
def leaf_name(node: ast.AST) -> str | None:
name = dotted_name(node)
if name is None:
return None
return name.rsplit(".", 1)[-1]
def int_constant(node: ast.AST | None) -> int | None:
if isinstance(node, ast.Constant) and isinstance(node.value, int):
return node.value
if isinstance(node, ast.Name):
return HTTP_STATUS_NAMES.get(node.id)
if isinstance(node, ast.Attribute):
return HTTP_STATUS_NAMES.get(node.attr)
return None
def string_constant(node: ast.AST | None) -> str | None:
if isinstance(node, ast.Constant) and isinstance(node.value, str):
return node.value
return None
def keyword_value(call: ast.Call, *names: str) -> ast.AST | None:
for keyword in call.keywords:
if keyword.arg in names:
return keyword.value
return None
def is_probable_model_name(name: str) -> bool:
return bool(name) and name[0].isupper()
def model_name_from_schema_expr(node: ast.AST | None) -> str | None:
if node is None:
return None
if isinstance(node, ast.Subscript):
value_name = dotted_name(node.value)
if value_name and value_name.endswith(".models"):
# register_response_schema_models stores schemas by model name; both
# ns.models[Model.__name__] and ns.models["Model"] appear in controllers.
key = node.slice
if isinstance(key, ast.Attribute) and key.attr == "__name__":
return leaf_name(key.value)
return string_constant(key)
if isinstance(node, ast.Call):
func_name = dotted_name(node.func)
if func_name and func_name.endswith(".model"):
return string_constant(node.args[0] if node.args else keyword_value(node, "name"))
if isinstance(node, ast.Name):
return node.id
return None
def documented_response_from_decorator(decorator: ast.expr) -> DocumentedResponse | None:
if not isinstance(decorator, ast.Call):
return None
func_name = dotted_name(decorator.func)
if not func_name or not func_name.endswith(".response"):
return None
status_expr = decorator.args[0] if decorator.args else keyword_value(decorator, "code", "status")
status = int_constant(status_expr)
if status is None:
return None
schema_expr: ast.AST | None = decorator.args[2] if len(decorator.args) >= 3 else None
schema_expr = keyword_value(decorator, "model", "schema") or schema_expr
return DocumentedResponse(
status=status,
model=model_name_from_schema_expr(schema_expr),
line=decorator.lineno,
)
def route_from_decorator(decorator: ast.expr) -> str | None:
if not isinstance(decorator, ast.Call):
return None
func_name = dotted_name(decorator.func)
if not func_name or not func_name.endswith(".route"):
return None
return string_constant(decorator.args[0] if decorator.args else keyword_value(decorator, "route", "path"))
def routes_from_decorators(decorators: Iterable[ast.expr]) -> list[str]:
return [route for decorator in decorators if (route := route_from_decorator(decorator))]
def response_docs_from_decorators(decorators: Iterable[ast.expr]) -> dict[int, DocumentedResponse]:
docs: dict[int, DocumentedResponse] = {}
for decorator in decorators:
response = documented_response_from_decorator(decorator)
if response and 200 <= response.status < 300:
docs[response.status] = response
return docs
def model_name_from_model_validate_call(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call):
return None
if isinstance(node.func, ast.Attribute) and node.func.attr == "model_validate":
return leaf_name(node.func.value)
return None
def model_name_from_constructor_call(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call):
return None
if isinstance(node.func, ast.Name) and is_probable_model_name(node.func.id):
return node.func.id
return None
def model_name_from_model_dump(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump":
return None
dumped_value = node.func.value
if isinstance(dumped_value, ast.Call):
return model_name_from_model_validate_call(dumped_value) or model_name_from_constructor_call(dumped_value)
return None
def model_name_from_model_value(node: ast.AST) -> str | None:
return model_name_from_model_validate_call(node) or model_name_from_constructor_call(node)
def model_name_from_dump_response(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call):
return None
func_name = dotted_name(node.func)
if func_name != "dump_response" and not (func_name and func_name.endswith(".dump_response")):
return None
model_expr = node.args[0] if node.args else keyword_value(node, "model", "schema", "response_model")
if isinstance(model_expr, ast.Name):
return model_expr.id
return None
def actual_kind_from_expr(
expr: ast.AST | None, variable_models: dict[str, str] | None = None
) -> tuple[ActualKind, str | None]:
if expr is None:
return "none", None
dump_response_model = model_name_from_dump_response(expr)
if dump_response_model:
return "model", dump_response_model
if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump":
dumped_value = expr.func.value
if isinstance(dumped_value, ast.Name) and variable_models:
# A variable dump can match today, but it bypasses dump_response and
# is easier to drift; keep it visible as refactorable.
model_name = variable_models.get(dumped_value.id)
if model_name:
return "model_dump_variable", model_name
model_dump_model = model_name_from_model_dump(expr)
if model_dump_model:
return "model", model_dump_model
if isinstance(expr, ast.Constant):
if expr.value is None:
return "none", None
if expr.value == "":
return "empty", None
return "raw_value", None
if isinstance(expr, ast.Dict):
return "raw_dict", None
if isinstance(expr, ast.List):
return "raw_list", None
return "unknown", None
def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse:
status: int | None = 200
body_expr = return_node.value
if isinstance(return_node.value, ast.Tuple) and return_node.value.elts:
body_expr = return_node.value.elts[0]
if len(return_node.value.elts) >= 2:
# Dynamic statuses are not safe to coerce to 200; classify them as unknown.
status = int_constant(return_node.value.elts[1])
kind, model = actual_kind_from_expr(body_expr, variable_models)
return ActualResponse(status=status, kind=kind, model=model, line=return_node.lineno)
def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
"""Yield method body nodes while ignoring nested function/class scopes."""
stack: list[ast.AST] = list(reversed(method.body))
while stack:
node = stack.pop()
yield node
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda | ast.ClassDef):
continue
stack.extend(reversed(list(ast.iter_child_nodes(node))))
def target_names(target: ast.AST) -> Iterable[str]:
if isinstance(target, ast.Name):
yield target.id
elif isinstance(target, ast.Tuple | ast.List):
for item in target.elts:
yield from target_names(item)
def record_assignment(
assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None
) -> None:
for target in targets:
if model_name is None:
# Once a name receives an unknown value, later model_dump() calls on it
# are no longer a reliable signal for the returned schema.
assignments[target].add_unknown()
else:
assignments[target].add_known(model_name)
def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]:
"""Infer local variables that are unambiguously assigned one response model."""
assignments: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary)
for node in iter_method_nodes(method):
match node:
case ast.Assign(targets=targets, value=value):
record_assignment(
assignments,
(name for target in targets for name in target_names(target)),
model_name_from_model_value(value),
)
case ast.AnnAssign(target=target, value=value) if value is not None:
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
case ast.AugAssign(target=target) | ast.For(target=target) | ast.AsyncFor(target=target):
# Mutation and loop targets overwrite prior values with runtime-dependent data.
record_assignment(assignments, target_names(target), None)
case ast.With(items=items) | ast.AsyncWith(items=items):
for item in items:
if item.optional_vars is not None:
record_assignment(assignments, target_names(item.optional_vars), None)
case ast.ExceptHandler(name=name) if name:
assignments[name].add_unknown()
case ast.NamedExpr(target=target, value=value):
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
return {name: model for name, summary in assignments.items() if (model := summary.single_known_model()) is not None}
def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]:
"""Extract statically visible 2xx returns from one controller method.
The analysis is deliberately shallow and conservative:
1. Walk only the method's own body, skipping nested functions/classes.
2. Infer local variables that are assigned exactly one recognizable response
model, so ``response.model_dump()`` can still be connected to its schema.
3. Treat any later unknown assignment, mutation target, loop target, context
manager binding, or exception binding as invalidating that variable.
4. For each top-level return path, split Flask-style ``(body, status)``
tuples, classify the body expression, and keep non-literal statuses as
``None`` so the classifier reports them as unknown instead of assuming 200.
5. Drop non-2xx literal statuses, since response contracts here only compare
successful response schemas.
"""
variable_models = variable_model_assignments_for_method(method)
responses: list[ActualResponse] = []
for node in iter_method_nodes(method):
if isinstance(node, ast.Return):
responses.append(actual_response_from_return(node, variable_models))
return [response for response in responses if response.status is None or 200 <= response.status < 300]
def display_path(file_path: Path, repo_root: Path) -> str:
try:
return str(file_path.relative_to(repo_root))
except ValueError:
return str(file_path)
def classify_method(
*,
actual_responses: Sequence[ActualResponse],
class_name: str,
documented_responses: dict[int, DocumentedResponse],
file_path: Path,
method: MethodNode,
repo_root: Path,
route: str,
) -> ContractCheck:
documented_summary = {status: response.model for status, response in sorted(documented_responses.items())}
context = ContractCheckContext(
file=display_path(file_path, repo_root),
class_name=class_name,
method=method.name,
route=route,
line=method.lineno,
documented=documented_summary,
)
if not actual_responses:
return context.build("unknown", "no statically visible 2xx return", [])
unknown_reasons: list[str] = []
refactorable_reasons: list[str] = []
for actual in actual_responses:
if actual.status is None:
unknown_reasons.append(f"return line {actual.line} has non-literal or unsupported status")
continue
documented = documented_responses.get(actual.status)
if actual.status in NO_BODY_STATUSES:
# No-body statuses are contract violations even when the schema names
# would otherwise match, because clients should not expect a payload.
if documented is not None and documented.model is not None:
return context.mismatch(
f"status {actual.status} is a no-body response but documents {documented.model}",
documented,
actual,
)
if actual.kind in {"model", "model_dump_variable", "raw_dict", "raw_list", "raw_value"}:
no_body_doc = DocumentedResponse(status=actual.status, model=None, line=method.lineno)
return context.mismatch(
f"status {actual.status} is a no-body response but returns {actual.kind}",
no_body_doc,
actual,
)
if actual.kind == "unknown":
unknown_reasons.append(f"status {actual.status} returns unknown body expression")
continue
if documented is None:
unknown_reasons.append(f"status {actual.status} has no @response doc")
continue
if documented.model is None:
unknown_reasons.append(f"status {actual.status} response doc has no schema model")
continue
if actual.kind == "model_dump_variable" and actual.model is not None:
if documented.model != actual.model:
return context.mismatch(
f"status {actual.status} documents {documented.model} but returns {actual.model}",
documented,
actual,
)
# The schema matches, but this path still deserves cleanup because
# dump_response is the contract-aware serialization helper.
refactorable_reasons.append(
f"status {actual.status} returns {actual.model}.model_dump() through a variable; prefer dump_response"
)
continue
if actual.kind != "model" or actual.model is None:
unknown_reasons.append(f"status {actual.status} returns {actual.kind}")
continue
if documented.model != actual.model:
return context.mismatch(
f"status {actual.status} documents {documented.model} but returns {actual.model}",
documented,
actual,
)
if unknown_reasons:
# Unknown beats refactorable: if any return path is ambiguous, do not
# imply the endpoint is merely a cleanup candidate.
return context.build("unknown", "; ".join(sorted(set(unknown_reasons))), actual_responses)
if refactorable_reasons:
return context.build("refactorable", "; ".join(sorted(set(refactorable_reasons))), actual_responses)
return context.build(
"valid",
"documented response schema matches statically visible return schema",
actual_responses,
)
def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]:
for path in paths:
if path.is_file() and path.suffix == ".py":
yield path
elif path.is_dir():
yield from sorted(child for child in path.rglob("*.py") if child.is_file())
def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]:
module = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
checks: list[ContractCheck] = []
for node in module.body:
if not isinstance(node, ast.ClassDef):
continue
class_routes = routes_from_decorators(node.decorator_list)
class_documented = response_docs_from_decorators(node.decorator_list)
for item in node.body:
if not isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) or item.name not in HTTP_METHODS:
continue
routes = routes_from_decorators(item.decorator_list) or class_routes
if not routes:
continue
documented = {**class_documented, **response_docs_from_decorators(item.decorator_list)}
# Method-level @response decorators override class-level defaults for
# the same status code, matching Flask-RESTX's common controller style.
actual = actual_responses_for_method(item)
for route in routes:
checks.append(
classify_method(
actual_responses=actual,
class_name=node.name,
documented_responses=documented,
file_path=file_path,
method=item,
repo_root=repo_root,
route=route,
)
)
return checks
def as_jsonable(check: ContractCheck) -> dict[str, Any]:
data = asdict(check)
data["documented"] = {str(status): model for status, model in check.documented.items()}
return data
def print_text_report(checks: Sequence[ContractCheck], *, include_valid: bool) -> None:
counts = Counter(check.classification for check in checks)
sys.stdout.write(
"Response contract lint: "
f"{counts['valid']} valid, "
f"{counts['mismatch']} mismatch, "
f"{counts['refactorable']} refactorable, "
f"{counts['unknown']} unknown\n"
)
for classification in ("mismatch", "refactorable", "unknown", "valid"):
filtered = [check for check in checks if check.classification == classification]
if classification == "valid" and not include_valid:
continue
if not filtered:
continue
sys.stdout.write(f"\n{classification.upper()}:\n")
for check in filtered:
location = f"{check.file}:{check.line} {check.class_name}.{check.method.upper()} {check.route}"
sys.stdout.write(f"- {location}: {check.reason}\n")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"paths",
nargs="*",
help="Files or directories to lint. Defaults to Flask controller directories.",
)
parser.add_argument("--include-valid", action="store_true", help="Print valid route methods in text output.")
parser.add_argument("--json", action="store_true", help="Emit machine-readable JSON.")
parser.add_argument(
"--fail-on-mismatch",
action="store_true",
help="Treat mismatched response contracts as failures. By default this linter is report-only.",
)
parser.add_argument(
"--fail-on-unknown",
action="store_true",
help="Treat unknown route methods as failures. By default this linter is report-only.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
api_root = Path(__file__).resolve().parents[1]
repo_root = api_root.parent
raw_paths = args.paths or list(DEFAULT_CONTROLLER_DIRS)
paths = [path if path.is_absolute() else api_root / path for path in map(Path, raw_paths)]
checks: list[ContractCheck] = []
for file_path in iter_controller_files(paths):
checks.extend(checks_for_file(file_path.resolve(), repo_root))
checks.sort(key=lambda check: (check.classification, check.file, check.line, check.method))
if args.json:
grouped = defaultdict(list)
for check in checks:
grouped[check.classification].append(as_jsonable(check))
sys.stdout.write(f"{json.dumps(grouped, indent=2, sort_keys=True)}\n")
else:
print_text_report(checks, include_valid=bool(args.include_valid))
has_mismatch = any(check.classification == "mismatch" for check in checks)
has_unknown = any(check.classification == "unknown" for check in checks)
return int((bool(args.fail_on_mismatch) and has_mismatch) or (bool(args.fail_on_unknown) and has_unknown))
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -0,0 +1,191 @@
import importlib.util
import sys
from pathlib import Path
def _load_lint_response_contracts_module():
api_dir = Path(__file__).parents[3]
script_path = api_dir / "dev" / "lint_response_contracts.py"
spec = importlib.util.spec_from_file_location("lint_response_contracts", script_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module
def _checks_for_source(tmp_path: Path, source: str):
module = _load_lint_response_contracts_module()
controller_path = tmp_path / "controllers" / "sample.py"
controller_path.parent.mkdir()
controller_path.write_text(source, encoding="utf-8")
return module.checks_for_file(controller_path, tmp_path)
def test_no_body_status_with_body_is_mismatch_while_empty_body_is_valid(tmp_path: Path):
checks = _checks_for_source(
tmp_path,
"""
@ns.route("/bad")
class BadDeleteApi(Resource):
@ns.response(204, "Deleted")
def delete(self):
return {"result": "success"}, 204
@ns.route("/ok")
class EmptyDeleteApi(Resource):
@ns.response(204, "Deleted")
def delete(self):
return "", 204
""",
)
assert [(check.class_name, check.classification) for check in checks] == [
("BadDeleteApi", "mismatch"),
("EmptyDeleteApi", "valid"),
]
assert "no-body response but returns raw_dict" in checks[0].reason
def test_variable_model_dump_is_refactorable_not_valid(tmp_path: Path):
checks = _checks_for_source(
tmp_path,
"""
from http import HTTPStatus
@ns.route("/annotations")
class AnnotationApi(Resource):
@ns.response(HTTPStatus.CREATED, "Created", ns.models[AnnotationResponse.__name__])
def post(self):
if use_existing:
response = AnnotationResponse.model_validate(existing, from_attributes=True)
else:
response = AnnotationResponse(id="new")
return response.model_dump(mode="json"), HTTPStatus.CREATED
""",
)
assert len(checks) == 1
assert checks[0].classification == "refactorable"
assert checks[0].actual[0].status == 201
assert checks[0].actual[0].kind == "model_dump_variable"
assert "prefer dump_response" in checks[0].reason
def test_variable_model_dump_with_wrong_documented_schema_is_mismatch(tmp_path: Path):
checks = _checks_for_source(
tmp_path,
"""
@ns.route("/annotations")
class AnnotationApi(Resource):
@ns.response(200, "OK", ns.models[DocumentedResponse.__name__])
def get(self):
response = ActualResponse.model_validate(data)
return response.model_dump(mode="json"), 200
""",
)
assert len(checks) == 1
assert checks[0].classification == "mismatch"
assert "documents DocumentedResponse but returns ActualResponse" in checks[0].reason
def test_nested_returns_are_ignored_for_outer_control_flow(tmp_path: Path):
checks = _checks_for_source(
tmp_path,
"""
@ns.route("/stream")
class StreamApi(Resource):
@ns.response(200, "OK", ns.models[StreamResponse.__name__])
def get(self):
def generate_events():
return dump_response(WrongResponse, {"event": "nested"}), 200
if finished:
return dump_response(StreamResponse, {"event": "done"}), 200
return dump_response(StreamResponse, {"event": "running"}), 200
""",
)
assert len(checks) == 1
assert checks[0].classification == "valid"
assert {actual.model for actual in checks[0].actual} == {"StreamResponse"}
def test_main_is_report_only_by_default_for_mismatches(tmp_path: Path, monkeypatch):
module = _load_lint_response_contracts_module()
controller_path = tmp_path / "controllers" / "sample.py"
controller_path.parent.mkdir()
controller_path.write_text(
"""
@ns.route("/bad")
class BadDeleteApi(Resource):
@ns.response(204, "Deleted")
def delete(self):
return {"result": "success"}, 204
""",
encoding="utf-8",
)
monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", str(controller_path)])
assert module.main() == 0
monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", "--fail-on-mismatch", str(controller_path)])
assert module.main() == 1
def test_class_level_route_and_response_docs_apply_to_methods(tmp_path: Path):
checks = _checks_for_source(
tmp_path,
"""
@ns.route(path="/items")
@ns.response(code=200, description="OK", model=ns.models[ItemListResponse.__name__])
class ItemListApi(Resource):
def get(self):
return dump_response(ItemListResponse, {"data": []}), 200
""",
)
assert len(checks) == 1
assert checks[0].classification == "valid"
assert checks[0].route == "/items"
def test_unknown_reassignment_prevents_variable_model_dump_inference(tmp_path: Path):
checks = _checks_for_source(
tmp_path,
"""
@ns.route("/items")
class ItemApi(Resource):
@ns.response(200, "OK", ns.models[ItemResponse.__name__])
def get(self):
response = ItemResponse.model_validate(item)
if refresh:
response = load_response()
return response.model_dump(mode="json"), 200
""",
)
assert len(checks) == 1
assert checks[0].classification == "unknown"
assert "returns unknown" in checks[0].reason
def test_non_literal_status_is_unknown_not_defaulted_to_200(tmp_path: Path):
checks = _checks_for_source(
tmp_path,
"""
@ns.route("/items")
class ItemApi(Resource):
@ns.response(200, "OK", ns.models[ItemResponse.__name__])
def get(self):
return dump_response(ItemResponse, item), status_code
""",
)
assert len(checks) == 1
assert checks[0].classification == "unknown"
assert "non-literal or unsupported status" in checks[0].reason