From f19702f76cbe8c46e57c521b9cbf8073e39ef14b Mon Sep 17 00:00:00 2001 From: chariri Date: Fri, 22 May 2026 00:05:06 +0900 Subject: [PATCH] feat(api): Flask-RESTX `response()` vs actual return value checker (#36488) --- Makefile | 11 +- api/AGENTS.md | 1 + api/dev/lint_response_contracts.py | 664 ++++++++++++++++++ .../commands/test_lint_response_contracts.py | 191 +++++ 4 files changed, 865 insertions(+), 2 deletions(-) create mode 100644 api/dev/lint_response_contracts.py create mode 100644 api/tests/unit_tests/commands/test_lint_response_contracts.py diff --git a/Makefile b/Makefile index 9d3ac4ee47..be665e7123 100644 --- a/Makefile +++ b/Makefile @@ -75,13 +75,19 @@ check: @echo "✅ Code check complete" lint: - @echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..." + @echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..." @uv run --project api --dev ruff format ./api @uv run --project api --dev ruff check --fix ./api + @$(MAKE) api-contract-lint @uv run --directory api --dev lint-imports @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example @echo "✅ Linting complete" +api-contract-lint: + @echo "🔎 Linting Flask response contracts..." + @uv run --project api --dev python api/dev/lint_response_contracts.py + @echo "✅ Response contract lint complete" + type-check: @echo "📝 Running type checks (pyrefly + mypy)..." @./dev/pyrefly-check-local $(PATH_TO_CHECK) @@ -191,6 +197,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" + @echo " make api-contract-lint - Check Flask response docs against returned schemas" @echo " make type-check - Run type checks (pyrefly, mypy)" @echo " make type-check-core - Run core type checks (pyrefly, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @@ -204,4 +211,4 @@ help: @echo " make build-push-all - Build and push all Docker images" # Phony targets -.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all +.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all diff --git a/api/AGENTS.md b/api/AGENTS.md index 4abd14e7c0..984322590b 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -195,6 +195,7 @@ Before opening a PR / submitting: - Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic. - Services: coordinate repositories, providers, background tasks; keep side effects explicit. - Document non-obvious behaviour with concise docstrings and comments. +- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload. - For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`. In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`, diff --git a/api/dev/lint_response_contracts.py b/api/dev/lint_response_contracts.py new file mode 100644 index 0000000000..6cdb3e289c --- /dev/null +++ b/api/dev/lint_response_contracts.py @@ -0,0 +1,664 @@ +"""Lint Flask-RESTX response docs against statically visible response serializers. + +This checker intentionally stays conservative. It only reports a hard schema +mismatch when both sides are statically known for the same 2xx status code: +a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)`` +or ``Model.model_validate(...).model_dump()`` return. + +Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing +response schemas, and returns with non-literal status codes are classified as +unknown so reviewers can triage them without blocking unrelated work. The one +intentional non-schema mismatch is a known body/schema on a no-body status such +as 204, 205, or 304. +""" + +from __future__ import annotations + +import argparse +import ast +import json +import sys +from collections import Counter, defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import asdict, dataclass, field +from http import HTTPStatus +from pathlib import Path +from typing import Any, Literal + +HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"} +NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value} +DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web") + +type Classification = Literal["valid", "mismatch", "unknown", "refactorable"] +type ActualKind = Literal[ + "empty", + "model", + "model_dump_variable", + "none", + "raw_dict", + "raw_list", + "raw_value", + "unknown", +] +type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef + +HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus} +HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus}) + + +@dataclass(frozen=True) +class DocumentedResponse: + status: int + model: str | None + line: int + + +@dataclass(frozen=True) +class ActualResponse: + status: int | None + kind: ActualKind + model: str | None + line: int + + +@dataclass(frozen=True) +class ContractCheck: + classification: Classification + file: str + class_name: str + method: str + route: str + line: int + reason: str + documented: dict[int, str | None] + actual: list[ActualResponse] + + +@dataclass(frozen=True) +class ContractCheckContext: + """Stable route-method context shared by every classification result.""" + + file: str + class_name: str + method: str + route: str + line: int + documented: dict[int, str | None] + + def build( + self, classification: Classification, reason: str, actual_responses: Sequence[ActualResponse] + ) -> ContractCheck: + return ContractCheck( + classification=classification, + file=self.file, + class_name=self.class_name, + method=self.method, + route=self.route, + line=self.line, + reason=reason, + documented=self.documented, + actual=list(actual_responses), + ) + + def mismatch(self, reason: str, documented: DocumentedResponse, actual: ActualResponse) -> ContractCheck: + return self.build("mismatch", f"{reason} (doc line {documented.line}, return line {actual.line})", [actual]) + + +@dataclass +class VariableAssignmentSummary: + """Track whether a local name is safe to treat as one specific response model.""" + + known_models: set[str] = field(default_factory=set) + has_unknown_assignment: bool = False + + def add_known(self, model: str) -> None: + self.known_models.add(model) + + def add_unknown(self) -> None: + self.has_unknown_assignment = True + + def single_known_model(self) -> str | None: + if self.has_unknown_assignment or len(self.known_models) != 1: + return None + return next(iter(self.known_models)) + + +def dotted_name(node: ast.AST) -> str | None: + match node: + case ast.Name(): + return node.id + case ast.Attribute(): + parent = dotted_name(node.value) + if parent: + return f"{parent}.{node.attr}" + return node.attr + return None + + +def leaf_name(node: ast.AST) -> str | None: + name = dotted_name(node) + if name is None: + return None + return name.rsplit(".", 1)[-1] + + +def int_constant(node: ast.AST | None) -> int | None: + if isinstance(node, ast.Constant) and isinstance(node.value, int): + return node.value + if isinstance(node, ast.Name): + return HTTP_STATUS_NAMES.get(node.id) + if isinstance(node, ast.Attribute): + return HTTP_STATUS_NAMES.get(node.attr) + return None + + +def string_constant(node: ast.AST | None) -> str | None: + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + return None + + +def keyword_value(call: ast.Call, *names: str) -> ast.AST | None: + for keyword in call.keywords: + if keyword.arg in names: + return keyword.value + return None + + +def is_probable_model_name(name: str) -> bool: + return bool(name) and name[0].isupper() + + +def model_name_from_schema_expr(node: ast.AST | None) -> str | None: + if node is None: + return None + + if isinstance(node, ast.Subscript): + value_name = dotted_name(node.value) + if value_name and value_name.endswith(".models"): + # register_response_schema_models stores schemas by model name; both + # ns.models[Model.__name__] and ns.models["Model"] appear in controllers. + key = node.slice + if isinstance(key, ast.Attribute) and key.attr == "__name__": + return leaf_name(key.value) + return string_constant(key) + + if isinstance(node, ast.Call): + func_name = dotted_name(node.func) + if func_name and func_name.endswith(".model"): + return string_constant(node.args[0] if node.args else keyword_value(node, "name")) + + if isinstance(node, ast.Name): + return node.id + + return None + + +def documented_response_from_decorator(decorator: ast.expr) -> DocumentedResponse | None: + if not isinstance(decorator, ast.Call): + return None + + func_name = dotted_name(decorator.func) + if not func_name or not func_name.endswith(".response"): + return None + + status_expr = decorator.args[0] if decorator.args else keyword_value(decorator, "code", "status") + status = int_constant(status_expr) + if status is None: + return None + + schema_expr: ast.AST | None = decorator.args[2] if len(decorator.args) >= 3 else None + schema_expr = keyword_value(decorator, "model", "schema") or schema_expr + + return DocumentedResponse( + status=status, + model=model_name_from_schema_expr(schema_expr), + line=decorator.lineno, + ) + + +def route_from_decorator(decorator: ast.expr) -> str | None: + if not isinstance(decorator, ast.Call): + return None + + func_name = dotted_name(decorator.func) + if not func_name or not func_name.endswith(".route"): + return None + + return string_constant(decorator.args[0] if decorator.args else keyword_value(decorator, "route", "path")) + + +def routes_from_decorators(decorators: Iterable[ast.expr]) -> list[str]: + return [route for decorator in decorators if (route := route_from_decorator(decorator))] + + +def response_docs_from_decorators(decorators: Iterable[ast.expr]) -> dict[int, DocumentedResponse]: + docs: dict[int, DocumentedResponse] = {} + for decorator in decorators: + response = documented_response_from_decorator(decorator) + if response and 200 <= response.status < 300: + docs[response.status] = response + return docs + + +def model_name_from_model_validate_call(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call): + return None + if isinstance(node.func, ast.Attribute) and node.func.attr == "model_validate": + return leaf_name(node.func.value) + return None + + +def model_name_from_constructor_call(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call): + return None + if isinstance(node.func, ast.Name) and is_probable_model_name(node.func.id): + return node.func.id + return None + + +def model_name_from_model_dump(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump": + return None + + dumped_value = node.func.value + if isinstance(dumped_value, ast.Call): + return model_name_from_model_validate_call(dumped_value) or model_name_from_constructor_call(dumped_value) + + return None + + +def model_name_from_model_value(node: ast.AST) -> str | None: + return model_name_from_model_validate_call(node) or model_name_from_constructor_call(node) + + +def model_name_from_dump_response(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call): + return None + + func_name = dotted_name(node.func) + if func_name != "dump_response" and not (func_name and func_name.endswith(".dump_response")): + return None + + model_expr = node.args[0] if node.args else keyword_value(node, "model", "schema", "response_model") + if isinstance(model_expr, ast.Name): + return model_expr.id + return None + + +def actual_kind_from_expr( + expr: ast.AST | None, variable_models: dict[str, str] | None = None +) -> tuple[ActualKind, str | None]: + if expr is None: + return "none", None + + dump_response_model = model_name_from_dump_response(expr) + if dump_response_model: + return "model", dump_response_model + + if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump": + dumped_value = expr.func.value + if isinstance(dumped_value, ast.Name) and variable_models: + # A variable dump can match today, but it bypasses dump_response and + # is easier to drift; keep it visible as refactorable. + model_name = variable_models.get(dumped_value.id) + if model_name: + return "model_dump_variable", model_name + + model_dump_model = model_name_from_model_dump(expr) + if model_dump_model: + return "model", model_dump_model + + if isinstance(expr, ast.Constant): + if expr.value is None: + return "none", None + if expr.value == "": + return "empty", None + return "raw_value", None + + if isinstance(expr, ast.Dict): + return "raw_dict", None + + if isinstance(expr, ast.List): + return "raw_list", None + + return "unknown", None + + +def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse: + status: int | None = 200 + body_expr = return_node.value + + if isinstance(return_node.value, ast.Tuple) and return_node.value.elts: + body_expr = return_node.value.elts[0] + if len(return_node.value.elts) >= 2: + # Dynamic statuses are not safe to coerce to 200; classify them as unknown. + status = int_constant(return_node.value.elts[1]) + + kind, model = actual_kind_from_expr(body_expr, variable_models) + return ActualResponse(status=status, kind=kind, model=model, line=return_node.lineno) + + +def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]: + """Yield method body nodes while ignoring nested function/class scopes.""" + + stack: list[ast.AST] = list(reversed(method.body)) + while stack: + node = stack.pop() + yield node + + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda | ast.ClassDef): + continue + + stack.extend(reversed(list(ast.iter_child_nodes(node)))) + + +def target_names(target: ast.AST) -> Iterable[str]: + if isinstance(target, ast.Name): + yield target.id + elif isinstance(target, ast.Tuple | ast.List): + for item in target.elts: + yield from target_names(item) + + +def record_assignment( + assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None +) -> None: + for target in targets: + if model_name is None: + # Once a name receives an unknown value, later model_dump() calls on it + # are no longer a reliable signal for the returned schema. + assignments[target].add_unknown() + else: + assignments[target].add_known(model_name) + + +def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]: + """Infer local variables that are unambiguously assigned one response model.""" + + assignments: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary) + + for node in iter_method_nodes(method): + match node: + case ast.Assign(targets=targets, value=value): + record_assignment( + assignments, + (name for target in targets for name in target_names(target)), + model_name_from_model_value(value), + ) + case ast.AnnAssign(target=target, value=value) if value is not None: + record_assignment(assignments, target_names(target), model_name_from_model_value(value)) + case ast.AugAssign(target=target) | ast.For(target=target) | ast.AsyncFor(target=target): + # Mutation and loop targets overwrite prior values with runtime-dependent data. + record_assignment(assignments, target_names(target), None) + case ast.With(items=items) | ast.AsyncWith(items=items): + for item in items: + if item.optional_vars is not None: + record_assignment(assignments, target_names(item.optional_vars), None) + case ast.ExceptHandler(name=name) if name: + assignments[name].add_unknown() + case ast.NamedExpr(target=target, value=value): + record_assignment(assignments, target_names(target), model_name_from_model_value(value)) + + return {name: model for name, summary in assignments.items() if (model := summary.single_known_model()) is not None} + + +def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]: + """Extract statically visible 2xx returns from one controller method. + + The analysis is deliberately shallow and conservative: + + 1. Walk only the method's own body, skipping nested functions/classes. + 2. Infer local variables that are assigned exactly one recognizable response + model, so ``response.model_dump()`` can still be connected to its schema. + 3. Treat any later unknown assignment, mutation target, loop target, context + manager binding, or exception binding as invalidating that variable. + 4. For each top-level return path, split Flask-style ``(body, status)`` + tuples, classify the body expression, and keep non-literal statuses as + ``None`` so the classifier reports them as unknown instead of assuming 200. + 5. Drop non-2xx literal statuses, since response contracts here only compare + successful response schemas. + """ + + variable_models = variable_model_assignments_for_method(method) + responses: list[ActualResponse] = [] + for node in iter_method_nodes(method): + if isinstance(node, ast.Return): + responses.append(actual_response_from_return(node, variable_models)) + return [response for response in responses if response.status is None or 200 <= response.status < 300] + + +def display_path(file_path: Path, repo_root: Path) -> str: + try: + return str(file_path.relative_to(repo_root)) + except ValueError: + return str(file_path) + + +def classify_method( + *, + actual_responses: Sequence[ActualResponse], + class_name: str, + documented_responses: dict[int, DocumentedResponse], + file_path: Path, + method: MethodNode, + repo_root: Path, + route: str, +) -> ContractCheck: + documented_summary = {status: response.model for status, response in sorted(documented_responses.items())} + context = ContractCheckContext( + file=display_path(file_path, repo_root), + class_name=class_name, + method=method.name, + route=route, + line=method.lineno, + documented=documented_summary, + ) + + if not actual_responses: + return context.build("unknown", "no statically visible 2xx return", []) + + unknown_reasons: list[str] = [] + refactorable_reasons: list[str] = [] + + for actual in actual_responses: + if actual.status is None: + unknown_reasons.append(f"return line {actual.line} has non-literal or unsupported status") + continue + + documented = documented_responses.get(actual.status) + + if actual.status in NO_BODY_STATUSES: + # No-body statuses are contract violations even when the schema names + # would otherwise match, because clients should not expect a payload. + if documented is not None and documented.model is not None: + return context.mismatch( + f"status {actual.status} is a no-body response but documents {documented.model}", + documented, + actual, + ) + if actual.kind in {"model", "model_dump_variable", "raw_dict", "raw_list", "raw_value"}: + no_body_doc = DocumentedResponse(status=actual.status, model=None, line=method.lineno) + return context.mismatch( + f"status {actual.status} is a no-body response but returns {actual.kind}", + no_body_doc, + actual, + ) + if actual.kind == "unknown": + unknown_reasons.append(f"status {actual.status} returns unknown body expression") + continue + + if documented is None: + unknown_reasons.append(f"status {actual.status} has no @response doc") + continue + + if documented.model is None: + unknown_reasons.append(f"status {actual.status} response doc has no schema model") + continue + + if actual.kind == "model_dump_variable" and actual.model is not None: + if documented.model != actual.model: + return context.mismatch( + f"status {actual.status} documents {documented.model} but returns {actual.model}", + documented, + actual, + ) + # The schema matches, but this path still deserves cleanup because + # dump_response is the contract-aware serialization helper. + refactorable_reasons.append( + f"status {actual.status} returns {actual.model}.model_dump() through a variable; prefer dump_response" + ) + continue + + if actual.kind != "model" or actual.model is None: + unknown_reasons.append(f"status {actual.status} returns {actual.kind}") + continue + + if documented.model != actual.model: + return context.mismatch( + f"status {actual.status} documents {documented.model} but returns {actual.model}", + documented, + actual, + ) + + if unknown_reasons: + # Unknown beats refactorable: if any return path is ambiguous, do not + # imply the endpoint is merely a cleanup candidate. + return context.build("unknown", "; ".join(sorted(set(unknown_reasons))), actual_responses) + + if refactorable_reasons: + return context.build("refactorable", "; ".join(sorted(set(refactorable_reasons))), actual_responses) + + return context.build( + "valid", + "documented response schema matches statically visible return schema", + actual_responses, + ) + + +def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]: + for path in paths: + if path.is_file() and path.suffix == ".py": + yield path + elif path.is_dir(): + yield from sorted(child for child in path.rglob("*.py") if child.is_file()) + + +def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]: + module = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path)) + checks: list[ContractCheck] = [] + + for node in module.body: + if not isinstance(node, ast.ClassDef): + continue + + class_routes = routes_from_decorators(node.decorator_list) + class_documented = response_docs_from_decorators(node.decorator_list) + + for item in node.body: + if not isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) or item.name not in HTTP_METHODS: + continue + + routes = routes_from_decorators(item.decorator_list) or class_routes + if not routes: + continue + + documented = {**class_documented, **response_docs_from_decorators(item.decorator_list)} + # Method-level @response decorators override class-level defaults for + # the same status code, matching Flask-RESTX's common controller style. + actual = actual_responses_for_method(item) + for route in routes: + checks.append( + classify_method( + actual_responses=actual, + class_name=node.name, + documented_responses=documented, + file_path=file_path, + method=item, + repo_root=repo_root, + route=route, + ) + ) + + return checks + + +def as_jsonable(check: ContractCheck) -> dict[str, Any]: + data = asdict(check) + data["documented"] = {str(status): model for status, model in check.documented.items()} + return data + + +def print_text_report(checks: Sequence[ContractCheck], *, include_valid: bool) -> None: + counts = Counter(check.classification for check in checks) + sys.stdout.write( + "Response contract lint: " + f"{counts['valid']} valid, " + f"{counts['mismatch']} mismatch, " + f"{counts['refactorable']} refactorable, " + f"{counts['unknown']} unknown\n" + ) + + for classification in ("mismatch", "refactorable", "unknown", "valid"): + filtered = [check for check in checks if check.classification == classification] + if classification == "valid" and not include_valid: + continue + if not filtered: + continue + + sys.stdout.write(f"\n{classification.upper()}:\n") + for check in filtered: + location = f"{check.file}:{check.line} {check.class_name}.{check.method.upper()} {check.route}" + sys.stdout.write(f"- {location}: {check.reason}\n") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "paths", + nargs="*", + help="Files or directories to lint. Defaults to Flask controller directories.", + ) + parser.add_argument("--include-valid", action="store_true", help="Print valid route methods in text output.") + parser.add_argument("--json", action="store_true", help="Emit machine-readable JSON.") + parser.add_argument( + "--fail-on-mismatch", + action="store_true", + help="Treat mismatched response contracts as failures. By default this linter is report-only.", + ) + parser.add_argument( + "--fail-on-unknown", + action="store_true", + help="Treat unknown route methods as failures. By default this linter is report-only.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + api_root = Path(__file__).resolve().parents[1] + repo_root = api_root.parent + raw_paths = args.paths or list(DEFAULT_CONTROLLER_DIRS) + paths = [path if path.is_absolute() else api_root / path for path in map(Path, raw_paths)] + + checks: list[ContractCheck] = [] + for file_path in iter_controller_files(paths): + checks.extend(checks_for_file(file_path.resolve(), repo_root)) + + checks.sort(key=lambda check: (check.classification, check.file, check.line, check.method)) + + if args.json: + grouped = defaultdict(list) + for check in checks: + grouped[check.classification].append(as_jsonable(check)) + sys.stdout.write(f"{json.dumps(grouped, indent=2, sort_keys=True)}\n") + else: + print_text_report(checks, include_valid=bool(args.include_valid)) + + has_mismatch = any(check.classification == "mismatch" for check in checks) + has_unknown = any(check.classification == "unknown" for check in checks) + return int((bool(args.fail_on_mismatch) and has_mismatch) or (bool(args.fail_on_unknown) and has_unknown)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/api/tests/unit_tests/commands/test_lint_response_contracts.py b/api/tests/unit_tests/commands/test_lint_response_contracts.py new file mode 100644 index 0000000000..8f3860f231 --- /dev/null +++ b/api/tests/unit_tests/commands/test_lint_response_contracts.py @@ -0,0 +1,191 @@ +import importlib.util +import sys +from pathlib import Path + + +def _load_lint_response_contracts_module(): + api_dir = Path(__file__).parents[3] + script_path = api_dir / "dev" / "lint_response_contracts.py" + spec = importlib.util.spec_from_file_location("lint_response_contracts", script_path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _checks_for_source(tmp_path: Path, source: str): + module = _load_lint_response_contracts_module() + controller_path = tmp_path / "controllers" / "sample.py" + controller_path.parent.mkdir() + controller_path.write_text(source, encoding="utf-8") + return module.checks_for_file(controller_path, tmp_path) + + +def test_no_body_status_with_body_is_mismatch_while_empty_body_is_valid(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/bad") +class BadDeleteApi(Resource): + @ns.response(204, "Deleted") + def delete(self): + return {"result": "success"}, 204 + + +@ns.route("/ok") +class EmptyDeleteApi(Resource): + @ns.response(204, "Deleted") + def delete(self): + return "", 204 +""", + ) + + assert [(check.class_name, check.classification) for check in checks] == [ + ("BadDeleteApi", "mismatch"), + ("EmptyDeleteApi", "valid"), + ] + assert "no-body response but returns raw_dict" in checks[0].reason + + +def test_variable_model_dump_is_refactorable_not_valid(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +from http import HTTPStatus + + +@ns.route("/annotations") +class AnnotationApi(Resource): + @ns.response(HTTPStatus.CREATED, "Created", ns.models[AnnotationResponse.__name__]) + def post(self): + if use_existing: + response = AnnotationResponse.model_validate(existing, from_attributes=True) + else: + response = AnnotationResponse(id="new") + return response.model_dump(mode="json"), HTTPStatus.CREATED +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "refactorable" + assert checks[0].actual[0].status == 201 + assert checks[0].actual[0].kind == "model_dump_variable" + assert "prefer dump_response" in checks[0].reason + + +def test_variable_model_dump_with_wrong_documented_schema_is_mismatch(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/annotations") +class AnnotationApi(Resource): + @ns.response(200, "OK", ns.models[DocumentedResponse.__name__]) + def get(self): + response = ActualResponse.model_validate(data) + return response.model_dump(mode="json"), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "mismatch" + assert "documents DocumentedResponse but returns ActualResponse" in checks[0].reason + + +def test_nested_returns_are_ignored_for_outer_control_flow(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/stream") +class StreamApi(Resource): + @ns.response(200, "OK", ns.models[StreamResponse.__name__]) + def get(self): + def generate_events(): + return dump_response(WrongResponse, {"event": "nested"}), 200 + + if finished: + return dump_response(StreamResponse, {"event": "done"}), 200 + return dump_response(StreamResponse, {"event": "running"}), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "valid" + assert {actual.model for actual in checks[0].actual} == {"StreamResponse"} + + +def test_main_is_report_only_by_default_for_mismatches(tmp_path: Path, monkeypatch): + module = _load_lint_response_contracts_module() + controller_path = tmp_path / "controllers" / "sample.py" + controller_path.parent.mkdir() + controller_path.write_text( + """ +@ns.route("/bad") +class BadDeleteApi(Resource): + @ns.response(204, "Deleted") + def delete(self): + return {"result": "success"}, 204 +""", + encoding="utf-8", + ) + + monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", str(controller_path)]) + assert module.main() == 0 + + monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", "--fail-on-mismatch", str(controller_path)]) + assert module.main() == 1 + + +def test_class_level_route_and_response_docs_apply_to_methods(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route(path="/items") +@ns.response(code=200, description="OK", model=ns.models[ItemListResponse.__name__]) +class ItemListApi(Resource): + def get(self): + return dump_response(ItemListResponse, {"data": []}), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "valid" + assert checks[0].route == "/items" + + +def test_unknown_reassignment_prevents_variable_model_dump_inference(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/items") +class ItemApi(Resource): + @ns.response(200, "OK", ns.models[ItemResponse.__name__]) + def get(self): + response = ItemResponse.model_validate(item) + if refresh: + response = load_response() + return response.model_dump(mode="json"), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "unknown" + assert "returns unknown" in checks[0].reason + + +def test_non_literal_status_is_unknown_not_defaulted_to_200(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/items") +class ItemApi(Resource): + @ns.response(200, "OK", ns.models[ItemResponse.__name__]) + def get(self): + return dump_response(ItemResponse, item), status_code +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "unknown" + assert "non-literal or unsupported status" in checks[0].reason