mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
feat(api): Flask-RESTX response() vs actual return value checker (#36488)
This commit is contained in:
parent
092c8bca81
commit
f19702f76c
11
Makefile
11
Makefile
@ -75,13 +75,19 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
|
||||
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
|
||||
@uv run --project api --dev ruff format ./api
|
||||
@uv run --project api --dev ruff check --fix ./api
|
||||
@$(MAKE) api-contract-lint
|
||||
@uv run --directory api --dev lint-imports
|
||||
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
api-contract-lint:
|
||||
@echo "🔎 Linting Flask response contracts..."
|
||||
@uv run --project api --dev python api/dev/lint_response_contracts.py
|
||||
@echo "✅ Response contract lint complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@ -191,6 +197,7 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
|
||||
@echo " make type-check - Run type checks (pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@ -204,4 +211,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all
|
||||
|
||||
@ -195,6 +195,7 @@ Before opening a PR / submitting:
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Document non-obvious behaviour with concise docstrings and comments.
|
||||
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
|
||||
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
|
||||
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
|
||||
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,
|
||||
|
||||
664
api/dev/lint_response_contracts.py
Normal file
664
api/dev/lint_response_contracts.py
Normal file
@ -0,0 +1,664 @@
|
||||
"""Lint Flask-RESTX response docs against statically visible response serializers.
|
||||
|
||||
This checker intentionally stays conservative. It only reports a hard schema
|
||||
mismatch when both sides are statically known for the same 2xx status code:
|
||||
a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)``
|
||||
or ``Model.model_validate(...).model_dump()`` return.
|
||||
|
||||
Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing
|
||||
response schemas, and returns with non-literal status codes are classified as
|
||||
unknown so reviewers can triage them without blocking unrelated work. The one
|
||||
intentional non-schema mismatch is a known body/schema on a no-body status such
|
||||
as 204, 205, or 304.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"}
|
||||
NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value}
|
||||
DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web")
|
||||
|
||||
type Classification = Literal["valid", "mismatch", "unknown", "refactorable"]
|
||||
type ActualKind = Literal[
|
||||
"empty",
|
||||
"model",
|
||||
"model_dump_variable",
|
||||
"none",
|
||||
"raw_dict",
|
||||
"raw_list",
|
||||
"raw_value",
|
||||
"unknown",
|
||||
]
|
||||
type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef
|
||||
|
||||
HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus}
|
||||
HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus})
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentedResponse:
|
||||
status: int
|
||||
model: str | None
|
||||
line: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ActualResponse:
    """One ``return`` statement reduced to its status, body kind, and model name."""

    # Literal status code, or None when the status expression could not be resolved.
    status: int | None
    # Coarse classification of the returned body expression.
    kind: ActualKind
    # Model name when the body kind carries one, otherwise None.
    model: str | None
    # Source line of the return statement.
    line: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ContractCheck:
    """Classification result for one controller method on one route."""

    classification: Classification
    # Display path of the controller file.
    file: str
    class_name: str
    # Handler method name (an HTTP verb in lower case).
    method: str
    route: str
    # Source line of the handler method.
    line: int
    # Human-readable explanation of the classification.
    reason: str
    # Documented status code -> schema model name (None when no model was recognised).
    documented: dict[int, str | None]
    # Returns that contributed to this result.
    actual: list[ActualResponse]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ContractCheckContext:
    """Stable route-method context shared by every classification result."""

    file: str
    class_name: str
    method: str
    route: str
    line: int
    documented: dict[int, str | None]

    def build(
        self, classification: Classification, reason: str, actual_responses: Sequence[ActualResponse]
    ) -> ContractCheck:
        """Create a ``ContractCheck`` for this context with the given outcome."""
        return ContractCheck(
            classification=classification,
            file=self.file,
            class_name=self.class_name,
            method=self.method,
            route=self.route,
            line=self.line,
            reason=reason,
            documented=self.documented,
            # Copy so the result does not alias the caller's sequence.
            actual=list(actual_responses),
        )

    def mismatch(self, reason: str, documented: DocumentedResponse, actual: ActualResponse) -> ContractCheck:
        """Create a mismatch result whose reason cites the doc line and the return line."""
        return self.build("mismatch", f"{reason} (doc line {documented.line}, return line {actual.line})", [actual])
|
||||
|
||||
|
||||
@dataclass
|
||||
class VariableAssignmentSummary:
|
||||
"""Track whether a local name is safe to treat as one specific response model."""
|
||||
|
||||
known_models: set[str] = field(default_factory=set)
|
||||
has_unknown_assignment: bool = False
|
||||
|
||||
def add_known(self, model: str) -> None:
|
||||
self.known_models.add(model)
|
||||
|
||||
def add_unknown(self) -> None:
|
||||
self.has_unknown_assignment = True
|
||||
|
||||
def single_known_model(self) -> str | None:
|
||||
if self.has_unknown_assignment or len(self.known_models) != 1:
|
||||
return None
|
||||
return next(iter(self.known_models))
|
||||
|
||||
|
||||
def dotted_name(node: ast.AST) -> str | None:
|
||||
match node:
|
||||
case ast.Name():
|
||||
return node.id
|
||||
case ast.Attribute():
|
||||
parent = dotted_name(node.value)
|
||||
if parent:
|
||||
return f"{parent}.{node.attr}"
|
||||
return node.attr
|
||||
return None
|
||||
|
||||
|
||||
def leaf_name(node: ast.AST) -> str | None:
    """Return the last component of the node's dotted name, or ``None`` if it has none."""
    name = dotted_name(node)
    if name is None:
        return None
    return name.rsplit(".", 1)[-1]
|
||||
|
||||
|
||||
def int_constant(node: ast.AST | None) -> int | None:
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, int):
|
||||
return node.value
|
||||
if isinstance(node, ast.Name):
|
||||
return HTTP_STATUS_NAMES.get(node.id)
|
||||
if isinstance(node, ast.Attribute):
|
||||
return HTTP_STATUS_NAMES.get(node.attr)
|
||||
return None
|
||||
|
||||
|
||||
def string_constant(node: ast.AST | None) -> str | None:
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, str):
|
||||
return node.value
|
||||
return None
|
||||
|
||||
|
||||
def keyword_value(call: ast.Call, *names: str) -> ast.AST | None:
|
||||
for keyword in call.keywords:
|
||||
if keyword.arg in names:
|
||||
return keyword.value
|
||||
return None
|
||||
|
||||
|
||||
def is_probable_model_name(name: str) -> bool:
    """Heuristic: a non-empty identifier starting with an upper-case letter is treated as a model class."""
    return bool(name) and name[0].isupper()
|
||||
|
||||
|
||||
def model_name_from_schema_expr(node: ast.AST | None) -> str | None:
|
||||
if node is None:
|
||||
return None
|
||||
|
||||
if isinstance(node, ast.Subscript):
|
||||
value_name = dotted_name(node.value)
|
||||
if value_name and value_name.endswith(".models"):
|
||||
# register_response_schema_models stores schemas by model name; both
|
||||
# ns.models[Model.__name__] and ns.models["Model"] appear in controllers.
|
||||
key = node.slice
|
||||
if isinstance(key, ast.Attribute) and key.attr == "__name__":
|
||||
return leaf_name(key.value)
|
||||
return string_constant(key)
|
||||
|
||||
if isinstance(node, ast.Call):
|
||||
func_name = dotted_name(node.func)
|
||||
if func_name and func_name.endswith(".model"):
|
||||
return string_constant(node.args[0] if node.args else keyword_value(node, "name"))
|
||||
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def documented_response_from_decorator(decorator: ast.expr) -> DocumentedResponse | None:
    """Parse a ``<x>.response(status, description, model)`` decorator.

    Returns ``None`` when the decorator is not a ``.response(...)`` call or its
    status cannot be resolved statically. The status is the first positional
    argument or the ``code``/``status`` keyword. The schema is the ``model``/
    ``schema`` keyword, falling back to the third positional argument.
    """
    if not isinstance(decorator, ast.Call):
        return None

    func_name = dotted_name(decorator.func)
    if not func_name or not func_name.endswith(".response"):
        return None

    status_expr = decorator.args[0] if decorator.args else keyword_value(decorator, "code", "status")
    status = int_constant(status_expr)
    if status is None:
        return None

    positional_schema: ast.AST | None = decorator.args[2] if len(decorator.args) >= 3 else None
    # An explicit keyword wins over the positional slot.
    schema_expr = keyword_value(decorator, "model", "schema") or positional_schema

    return DocumentedResponse(
        status=status,
        model=model_name_from_schema_expr(schema_expr),
        line=decorator.lineno,
    )
|
||||
|
||||
|
||||
def route_from_decorator(decorator: ast.expr) -> str | None:
|
||||
if not isinstance(decorator, ast.Call):
|
||||
return None
|
||||
|
||||
func_name = dotted_name(decorator.func)
|
||||
if not func_name or not func_name.endswith(".route"):
|
||||
return None
|
||||
|
||||
return string_constant(decorator.args[0] if decorator.args else keyword_value(decorator, "route", "path"))
|
||||
|
||||
|
||||
def routes_from_decorators(decorators: Iterable[ast.expr]) -> list[str]:
    """Collect the non-empty literal route paths declared by the given decorators, in order."""
    return [route for decorator in decorators if (route := route_from_decorator(decorator))]
|
||||
|
||||
|
||||
def response_docs_from_decorators(decorators: Iterable[ast.expr]) -> dict[int, DocumentedResponse]:
    """Map each documented 2xx status to its response doc.

    Non-2xx docs are ignored. When a status is documented more than once, the
    decorator that appears later in the iteration wins.
    """
    docs: dict[int, DocumentedResponse] = {}
    for decorator in decorators:
        response = documented_response_from_decorator(decorator)
        if response and 200 <= response.status < 300:
            docs[response.status] = response
    return docs
|
||||
|
||||
|
||||
def model_name_from_model_validate_call(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call):
|
||||
return None
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr == "model_validate":
|
||||
return leaf_name(node.func.value)
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_constructor_call(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call):
|
||||
return None
|
||||
if isinstance(node.func, ast.Name) and is_probable_model_name(node.func.id):
|
||||
return node.func.id
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_model_dump(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump":
|
||||
return None
|
||||
|
||||
dumped_value = node.func.value
|
||||
if isinstance(dumped_value, ast.Call):
|
||||
return model_name_from_model_validate_call(dumped_value) or model_name_from_constructor_call(dumped_value)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_model_value(node: ast.AST) -> str | None:
    """Return the model name when ``node`` is ``Model.model_validate(...)`` or ``Model(...)``."""
    return model_name_from_model_validate_call(node) or model_name_from_constructor_call(node)
|
||||
|
||||
|
||||
def model_name_from_dump_response(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call):
|
||||
return None
|
||||
|
||||
func_name = dotted_name(node.func)
|
||||
if func_name != "dump_response" and not (func_name and func_name.endswith(".dump_response")):
|
||||
return None
|
||||
|
||||
model_expr = node.args[0] if node.args else keyword_value(node, "model", "schema", "response_model")
|
||||
if isinstance(model_expr, ast.Name):
|
||||
return model_expr.id
|
||||
return None
|
||||
|
||||
|
||||
def actual_kind_from_expr(
    expr: ast.AST | None, variable_models: dict[str, str] | None = None
) -> tuple[ActualKind, str | None]:
    """Classify a returned body expression as ``(kind, model name)``.

    ``variable_models`` maps local variable names to the single model they are
    known to hold; it lets ``variable.model_dump()`` be tied to a schema.
    The model name is only set for the ``model`` and ``model_dump_variable`` kinds.
    """
    if expr is None:
        return "none", None

    dump_response_model = model_name_from_dump_response(expr)
    if dump_response_model:
        return "model", dump_response_model

    if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump":
        dumped_value = expr.func.value
        if isinstance(dumped_value, ast.Name) and variable_models:
            # A variable dump can match today, but it bypasses dump_response and
            # is easier to drift; keep it visible as refactorable.
            model_name = variable_models.get(dumped_value.id)
            if model_name:
                return "model_dump_variable", model_name

    model_dump_model = model_name_from_model_dump(expr)
    if model_dump_model:
        return "model", model_dump_model

    if isinstance(expr, ast.Constant):
        if expr.value is None:
            return "none", None
        if expr.value == "":
            return "empty", None
        return "raw_value", None

    if isinstance(expr, ast.Dict):
        return "raw_dict", None

    if isinstance(expr, ast.List):
        return "raw_list", None

    return "unknown", None
|
||||
|
||||
|
||||
def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse:
    """Split a ``return`` into body and status and classify the body.

    A bare ``return body`` is treated as status 200. For a tuple return the
    first element is the body and the second the status; a status that cannot
    be resolved statically is recorded as ``None``.
    """
    status: int | None = 200
    body_expr = return_node.value

    if isinstance(return_node.value, ast.Tuple) and return_node.value.elts:
        elements = return_node.value.elts
        body_expr = elements[0]
        if len(elements) >= 2:
            # Dynamic statuses are not safe to coerce to 200; classify them as unknown.
            status = int_constant(elements[1])

    kind, model = actual_kind_from_expr(body_expr, variable_models)
    return ActualResponse(status=status, kind=kind, model=model, line=return_node.lineno)
|
||||
|
||||
|
||||
def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
|
||||
"""Yield method body nodes while ignoring nested function/class scopes."""
|
||||
|
||||
stack: list[ast.AST] = list(reversed(method.body))
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
yield node
|
||||
|
||||
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda | ast.ClassDef):
|
||||
continue
|
||||
|
||||
stack.extend(reversed(list(ast.iter_child_nodes(node))))
|
||||
|
||||
|
||||
def target_names(target: ast.AST) -> Iterable[str]:
|
||||
if isinstance(target, ast.Name):
|
||||
yield target.id
|
||||
elif isinstance(target, ast.Tuple | ast.List):
|
||||
for item in target.elts:
|
||||
yield from target_names(item)
|
||||
|
||||
|
||||
def record_assignment(
    assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None
) -> None:
    """Record one assignment for every target name.

    ``model_name`` is the recognised model of the assigned value, or ``None``
    when the value's model is not statically known.
    """
    for target in targets:
        summary = assignments[target]
        if model_name is None:
            # Once a name receives an unknown value, later model_dump() calls on it
            # are no longer a reliable signal for the returned schema.
            summary.add_unknown()
        else:
            summary.add_known(model_name)
|
||||
|
||||
|
||||
def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]:
|
||||
"""Infer local variables that are unambiguously assigned one response model."""
|
||||
|
||||
assignments: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary)
|
||||
|
||||
for node in iter_method_nodes(method):
|
||||
match node:
|
||||
case ast.Assign(targets=targets, value=value):
|
||||
record_assignment(
|
||||
assignments,
|
||||
(name for target in targets for name in target_names(target)),
|
||||
model_name_from_model_value(value),
|
||||
)
|
||||
case ast.AnnAssign(target=target, value=value) if value is not None:
|
||||
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
|
||||
case ast.AugAssign(target=target) | ast.For(target=target) | ast.AsyncFor(target=target):
|
||||
# Mutation and loop targets overwrite prior values with runtime-dependent data.
|
||||
record_assignment(assignments, target_names(target), None)
|
||||
case ast.With(items=items) | ast.AsyncWith(items=items):
|
||||
for item in items:
|
||||
if item.optional_vars is not None:
|
||||
record_assignment(assignments, target_names(item.optional_vars), None)
|
||||
case ast.ExceptHandler(name=name) if name:
|
||||
assignments[name].add_unknown()
|
||||
case ast.NamedExpr(target=target, value=value):
|
||||
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
|
||||
|
||||
return {name: model for name, summary in assignments.items() if (model := summary.single_known_model()) is not None}
|
||||
|
||||
|
||||
def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]:
    """Extract statically visible 2xx returns from one controller method.

    The analysis is deliberately shallow and conservative:

    1. Walk only the method's own body, skipping nested functions/classes.
    2. Infer local variables that are assigned exactly one recognizable response
       model, so ``response.model_dump()`` can still be connected to its schema.
    3. Treat any later unknown assignment, mutation target, loop target, context
       manager binding, or exception binding as invalidating that variable.
    4. For each top-level return path, split Flask-style ``(body, status)``
       tuples, classify the body expression, and keep non-literal statuses as
       ``None`` so the classifier reports them as unknown instead of assuming 200.
    5. Drop non-2xx literal statuses, since response contracts here only compare
       successful response schemas.
    """
    variable_models = variable_model_assignments_for_method(method)
    responses = [
        actual_response_from_return(node, variable_models)
        for node in iter_method_nodes(method)
        if isinstance(node, ast.Return)
    ]
    return [response for response in responses if response.status is None or 200 <= response.status < 300]
|
||||
|
||||
|
||||
def display_path(file_path: Path, repo_root: Path) -> str:
    """Return ``file_path`` relative to ``repo_root``, or unchanged when it lies outside it."""
    try:
        return str(file_path.relative_to(repo_root))
    except ValueError:
        # relative_to raises when file_path is not under repo_root.
        return str(file_path)
|
||||
|
||||
|
||||
def classify_method(
    *,
    actual_responses: Sequence[ActualResponse],
    class_name: str,
    documented_responses: dict[int, DocumentedResponse],
    file_path: Path,
    method: MethodNode,
    repo_root: Path,
    route: str,
) -> ContractCheck:
    """Compare one method's returns with its documented responses.

    The first hard mismatch found is returned immediately. Otherwise the result
    is ``unknown`` if any return could not be judged, ``refactorable`` if every
    judged return matched but some used ``variable.model_dump()``, and
    ``valid`` when nothing was flagged.
    """
    documented_summary = {status: response.model for status, response in sorted(documented_responses.items())}
    context = ContractCheckContext(
        file=display_path(file_path, repo_root),
        class_name=class_name,
        method=method.name,
        route=route,
        line=method.lineno,
        documented=documented_summary,
    )

    if not actual_responses:
        return context.build("unknown", "no statically visible 2xx return", [])

    unknown_reasons: list[str] = []
    refactorable_reasons: list[str] = []

    for actual in actual_responses:
        if actual.status is None:
            unknown_reasons.append(f"return line {actual.line} has non-literal or unsupported status")
            continue

        documented = documented_responses.get(actual.status)

        if actual.status in NO_BODY_STATUSES:
            # No-body statuses are contract violations even when the schema names
            # would otherwise match, because clients should not expect a payload.
            if documented is not None and documented.model is not None:
                return context.mismatch(
                    f"status {actual.status} is a no-body response but documents {documented.model}",
                    documented,
                    actual,
                )
            if actual.kind in {"model", "model_dump_variable", "raw_dict", "raw_list", "raw_value"}:
                # Cite the real decorator line when one exists; fall back to the
                # method line only when the status is undocumented.
                no_body_doc = (
                    documented
                    if documented is not None
                    else DocumentedResponse(status=actual.status, model=None, line=method.lineno)
                )
                return context.mismatch(
                    f"status {actual.status} is a no-body response but returns {actual.kind}",
                    no_body_doc,
                    actual,
                )
            if actual.kind == "unknown":
                unknown_reasons.append(f"status {actual.status} returns unknown body expression")
            # An empty/None body on a no-body status has no schema to compare,
            # so the schema checks below never apply to these statuses.
            continue

        if documented is None:
            unknown_reasons.append(f"status {actual.status} has no @response doc")
            continue

        if documented.model is None:
            unknown_reasons.append(f"status {actual.status} response doc has no schema model")
            continue

        if actual.kind == "model_dump_variable" and actual.model is not None:
            if documented.model != actual.model:
                return context.mismatch(
                    f"status {actual.status} documents {documented.model} but returns {actual.model}",
                    documented,
                    actual,
                )
            # The schema matches, but this path still deserves cleanup because
            # dump_response is the contract-aware serialization helper.
            refactorable_reasons.append(
                f"status {actual.status} returns {actual.model}.model_dump() through a variable; prefer dump_response"
            )
            continue

        if actual.kind != "model" or actual.model is None:
            unknown_reasons.append(f"status {actual.status} returns {actual.kind}")
            continue

        if documented.model != actual.model:
            return context.mismatch(
                f"status {actual.status} documents {documented.model} but returns {actual.model}",
                documented,
                actual,
            )

    if unknown_reasons:
        # Unknown beats refactorable: if any return path is ambiguous, do not
        # imply the endpoint is merely a cleanup candidate.
        return context.build("unknown", "; ".join(sorted(set(unknown_reasons))), actual_responses)

    if refactorable_reasons:
        return context.build("refactorable", "; ".join(sorted(set(refactorable_reasons))), actual_responses)

    return context.build(
        "valid",
        "documented response schema matches statically visible return schema",
        actual_responses,
    )
|
||||
|
||||
|
||||
def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]:
    """Yield the Python files named by ``paths``.

    A ``.py`` file is yielded as-is. A directory is searched recursively and
    its ``.py`` files are yielded in sorted order. Anything else, including a
    path that does not exist, is skipped.
    """
    for path in paths:
        if path.is_file() and path.suffix == ".py":
            yield path
        elif path.is_dir():
            yield from sorted(child for child in path.rglob("*.py") if child.is_file())
|
||||
|
||||
|
||||
def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]:
    """Classify every routed HTTP handler method in one controller file.

    Only top-level classes are inspected. A handler is a method named after an
    HTTP verb. It is checked once per route, using its own ``.route(...)``
    decorators or, if it has none, those of its class. Handlers with no
    statically visible route are skipped.
    """
    module = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
    checks: list[ContractCheck] = []

    for node in module.body:
        if not isinstance(node, ast.ClassDef):
            continue

        class_routes = routes_from_decorators(node.decorator_list)
        class_documented = response_docs_from_decorators(node.decorator_list)

        for item in node.body:
            if not isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) or item.name not in HTTP_METHODS:
                continue

            routes = routes_from_decorators(item.decorator_list) or class_routes
            if not routes:
                continue

            # Method-level @response decorators override class-level defaults for
            # the same status code, matching Flask-RESTX's common controller style.
            documented = {**class_documented, **response_docs_from_decorators(item.decorator_list)}
            actual = actual_responses_for_method(item)
            checks.extend(
                classify_method(
                    actual_responses=actual,
                    class_name=node.name,
                    documented_responses=documented,
                    file_path=file_path,
                    method=item,
                    repo_root=repo_root,
                    route=route,
                )
                for route in routes
            )

    return checks
|
||||
|
||||
|
||||
def as_jsonable(check: ContractCheck) -> dict[str, Any]:
    """Convert a check to a plain dict whose ``documented`` keys are strings."""
    data = asdict(check)
    # Status codes become explicit string keys for the JSON output.
    data["documented"] = {str(status): model for status, model in check.documented.items()}
    return data
|
||||
|
||||
|
||||
def print_text_report(checks: Sequence[ContractCheck], *, include_valid: bool) -> None:
    """Write a summary line, then one section per classification, to stdout.

    Sections appear in the order mismatch, refactorable, unknown, valid. Empty
    sections are omitted, and the valid section is printed only when
    ``include_valid`` is true.
    """
    counts = Counter(check.classification for check in checks)
    sys.stdout.write(
        "Response contract lint: "
        f"{counts['valid']} valid, "
        f"{counts['mismatch']} mismatch, "
        f"{counts['refactorable']} refactorable, "
        f"{counts['unknown']} unknown\n"
    )

    for classification in ("mismatch", "refactorable", "unknown", "valid"):
        if classification == "valid" and not include_valid:
            continue
        filtered = [check for check in checks if check.classification == classification]
        if not filtered:
            continue

        sys.stdout.write(f"\n{classification.upper()}:\n")
        for check in filtered:
            location = f"{check.file}:{check.line} {check.class_name}.{check.method.upper()} {check.route}"
            sys.stdout.write(f"- {location}: {check.reason}\n")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"paths",
|
||||
nargs="*",
|
||||
help="Files or directories to lint. Defaults to Flask controller directories.",
|
||||
)
|
||||
parser.add_argument("--include-valid", action="store_true", help="Print valid route methods in text output.")
|
||||
parser.add_argument("--json", action="store_true", help="Emit machine-readable JSON.")
|
||||
parser.add_argument(
|
||||
"--fail-on-mismatch",
|
||||
action="store_true",
|
||||
help="Treat mismatched response contracts as failures. By default this linter is report-only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fail-on-unknown",
|
||||
action="store_true",
|
||||
help="Treat unknown route methods as failures. By default this linter is report-only.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
    """Run the linter and return the process exit code.

    Relative paths are resolved against the API root (the parent of this
    script's directory), not the current working directory. The exit code is
    1 only when a ``--fail-on-*`` flag is set and a matching result exists.
    """
    args = parse_args()
    api_root = Path(__file__).resolve().parents[1]
    repo_root = api_root.parent
    raw_paths = args.paths or list(DEFAULT_CONTROLLER_DIRS)
    paths = [path if path.is_absolute() else api_root / path for path in map(Path, raw_paths)]

    for path in paths:
        if not path.exists():
            # A mistyped path would otherwise give an empty, passing report.
            # Warn on stderr so stdout (possibly JSON) stays intact.
            sys.stderr.write(f"warning: path does not exist, skipping: {path}\n")

    checks: list[ContractCheck] = []
    for file_path in iter_controller_files(paths):
        checks.extend(checks_for_file(file_path.resolve(), repo_root))

    checks.sort(key=lambda check: (check.classification, check.file, check.line, check.method))

    if args.json:
        grouped = defaultdict(list)
        for check in checks:
            grouped[check.classification].append(as_jsonable(check))
        sys.stdout.write(f"{json.dumps(grouped, indent=2, sort_keys=True)}\n")
    else:
        print_text_report(checks, include_valid=bool(args.include_valid))

    has_mismatch = any(check.classification == "mismatch" for check in checks)
    has_unknown = any(check.classification == "unknown" for check in checks)
    return int((bool(args.fail_on_mismatch) and has_mismatch) or (bool(args.fail_on_unknown) and has_unknown))
|
||||
|
||||
|
||||
# Script entry point: exit with the status returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
191
api/tests/unit_tests/commands/test_lint_response_contracts.py
Normal file
191
api/tests/unit_tests/commands/test_lint_response_contracts.py
Normal file
@ -0,0 +1,191 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_lint_response_contracts_module():
    """Load ``dev/lint_response_contracts.py`` from its file path and return the module."""
    # parents[3] of tests/unit_tests/commands/<this file> is the API root.
    api_dir = Path(__file__).parents[3]
    script_path = api_dir / "dev" / "lint_response_contracts.py"
    spec = importlib.util.spec_from_file_location("lint_response_contracts", script_path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    # Register the module before executing it.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def _checks_for_source(tmp_path: Path, source: str):
    """Write ``source`` to ``controllers/sample.py`` under ``tmp_path`` and return its checks."""
    module = _load_lint_response_contracts_module()
    controller_path = tmp_path / "controllers" / "sample.py"
    # Tolerate an existing directory so the helper can be called more than once per test.
    controller_path.parent.mkdir(parents=True, exist_ok=True)
    controller_path.write_text(source, encoding="utf-8")
    return module.checks_for_file(controller_path, tmp_path)
|
||||
|
||||
|
||||
def test_no_body_status_with_body_is_mismatch_while_empty_body_is_valid(tmp_path: Path):
    """A dict returned with 204 is a mismatch; an empty string with 204 is valid."""
    checks = _checks_for_source(
        tmp_path,
        """
@ns.route("/bad")
class BadDeleteApi(Resource):
    @ns.response(204, "Deleted")
    def delete(self):
        return {"result": "success"}, 204


@ns.route("/ok")
class EmptyDeleteApi(Resource):
    @ns.response(204, "Deleted")
    def delete(self):
        return "", 204
""",
    )

    assert [(check.class_name, check.classification) for check in checks] == [
        ("BadDeleteApi", "mismatch"),
        ("EmptyDeleteApi", "valid"),
    ]
    assert "no-body response but returns raw_dict" in checks[0].reason
|
||||
|
||||
|
||||
def test_variable_model_dump_is_refactorable_not_valid(tmp_path: Path):
    """A matching ``variable.model_dump()`` return is reported as refactorable."""
    checks = _checks_for_source(
        tmp_path,
        """
from http import HTTPStatus


@ns.route("/annotations")
class AnnotationApi(Resource):
    @ns.response(HTTPStatus.CREATED, "Created", ns.models[AnnotationResponse.__name__])
    def post(self):
        if use_existing:
            response = AnnotationResponse.model_validate(existing, from_attributes=True)
        else:
            response = AnnotationResponse(id="new")
        return response.model_dump(mode="json"), HTTPStatus.CREATED
""",
    )

    assert len(checks) == 1
    assert checks[0].classification == "refactorable"
    assert checks[0].actual[0].status == 201
    assert checks[0].actual[0].kind == "model_dump_variable"
    assert "prefer dump_response" in checks[0].reason
|
||||
|
||||
|
||||
def test_variable_model_dump_with_wrong_documented_schema_is_mismatch(tmp_path: Path):
    """A ``variable.model_dump()`` return of a different model than documented is a mismatch."""
    checks = _checks_for_source(
        tmp_path,
        """
@ns.route("/annotations")
class AnnotationApi(Resource):
    @ns.response(200, "OK", ns.models[DocumentedResponse.__name__])
    def get(self):
        response = ActualResponse.model_validate(data)
        return response.model_dump(mode="json"), 200
""",
    )

    assert len(checks) == 1
    assert checks[0].classification == "mismatch"
    assert "documents DocumentedResponse but returns ActualResponse" in checks[0].reason
|
||||
|
||||
|
||||
def test_nested_returns_are_ignored_for_outer_control_flow(tmp_path: Path):
    """Returns inside a nested function do not count as returns of the handler."""
    checks = _checks_for_source(
        tmp_path,
        """
@ns.route("/stream")
class StreamApi(Resource):
    @ns.response(200, "OK", ns.models[StreamResponse.__name__])
    def get(self):
        def generate_events():
            return dump_response(WrongResponse, {"event": "nested"}), 200

        if finished:
            return dump_response(StreamResponse, {"event": "done"}), 200
        return dump_response(StreamResponse, {"event": "running"}), 200
""",
    )

    assert len(checks) == 1
    assert checks[0].classification == "valid"
    assert {actual.model for actual in checks[0].actual} == {"StreamResponse"}
|
||||
|
||||
|
||||
def test_main_is_report_only_by_default_for_mismatches(tmp_path: Path, monkeypatch):
    """``main`` exits 0 on a mismatch unless ``--fail-on-mismatch`` is passed."""
    module = _load_lint_response_contracts_module()
    controller_path = tmp_path / "controllers" / "sample.py"
    controller_path.parent.mkdir()
    controller_path.write_text(
        """
@ns.route("/bad")
class BadDeleteApi(Resource):
    @ns.response(204, "Deleted")
    def delete(self):
        return {"result": "success"}, 204
""",
        encoding="utf-8",
    )

    monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", str(controller_path)])
    assert module.main() == 0

    monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", "--fail-on-mismatch", str(controller_path)])
    assert module.main() == 1
|
||||
|
||||
|
||||
def test_class_level_route_and_response_docs_apply_to_methods(tmp_path: Path):
    """Keyword-form class decorators supply the route and response doc for undecorated methods."""
    checks = _checks_for_source(
        tmp_path,
        """
@ns.route(path="/items")
@ns.response(code=200, description="OK", model=ns.models[ItemListResponse.__name__])
class ItemListApi(Resource):
    def get(self):
        return dump_response(ItemListResponse, {"data": []}), 200
""",
    )

    assert len(checks) == 1
    assert checks[0].classification == "valid"
    assert checks[0].route == "/items"
|
||||
|
||||
|
||||
def test_unknown_reassignment_prevents_variable_model_dump_inference(tmp_path: Path):
    """Reassigning the variable from an unrecognised call makes its dump unknown."""
    checks = _checks_for_source(
        tmp_path,
        """
@ns.route("/items")
class ItemApi(Resource):
    @ns.response(200, "OK", ns.models[ItemResponse.__name__])
    def get(self):
        response = ItemResponse.model_validate(item)
        if refresh:
            response = load_response()
        return response.model_dump(mode="json"), 200
""",
    )

    assert len(checks) == 1
    assert checks[0].classification == "unknown"
    assert "returns unknown" in checks[0].reason
|
||||
|
||||
|
||||
def test_non_literal_status_is_unknown_not_defaulted_to_200(tmp_path: Path):
    """A status given as a variable is reported as unknown rather than assumed to be 200."""
    checks = _checks_for_source(
        tmp_path,
        """
@ns.route("/items")
class ItemApi(Resource):
    @ns.response(200, "OK", ns.models[ItemResponse.__name__])
    def get(self):
        return dump_response(ItemResponse, item), status_code
""",
    )

    assert len(checks) == 1
    assert checks[0].classification == "unknown"
    assert "non-literal or unsupported status" in checks[0].reason
|
||||
Loading…
Reference in New Issue
Block a user