diff --git a/api/controllers/openapi/__init__.py b/api/controllers/openapi/__init__.py index c11019cf627..81c65ca03be 100644 --- a/api/controllers/openapi/__init__.py +++ b/api/controllers/openapi/__init__.py @@ -20,7 +20,7 @@ openapi_ns = Namespace("openapi", description="User-scoped operations", path="/" # Register response/query models BEFORE importing controller modules so that # @openapi_ns.response / @openapi_ns.expect decorators can resolve model names. -from controllers.common.fields import EventStreamResponse +from controllers.common.fields import EventStreamResponse, SimpleResultResponse from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models from controllers.openapi._models import ( AccountPayload, @@ -95,6 +95,7 @@ register_response_schema_models( openapi_ns, ErrorBody, EventStreamResponse, + SimpleResultResponse, UsageInfo, MessageMetadata, AppListRow, diff --git a/api/controllers/openapi/app_run.py b/api/controllers/openapi/app_run.py index 7e77e3aa747..a22534ae82c 100644 --- a/api/controllers/openapi/app_run.py +++ b/api/controllers/openapi/app_run.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from collections.abc import Callable, Iterator +from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Any @@ -61,7 +61,7 @@ logger = logging.getLogger(__name__) @contextmanager -def _translate_service_errors() -> Iterator[None]: +def _translate_service_errors() -> Generator[None, None, None]: try: yield except WorkflowNotFoundError as ex: @@ -166,6 +166,7 @@ class AppRunApi(Resource): surface="apps", ) + # response-contract:ignore compact_generate_response return helper.compact_generate_response(stream_obj) diff --git a/api/dev/lint_response_contracts.py b/api/dev/lint_response_contracts.py index 75c5f67b8ff..77a2d2dc818 100644 --- a/api/dev/lint_response_contracts.py +++ b/api/dev/lint_response_contracts.py @@ -2,8 +2,8 @@ This checker intentionally stays conservative. It only reports a hard schema mismatch when both sides are statically known for the same 2xx status code: -a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)`` -or ``Model.model_validate(...).model_dump()`` return. +a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)``, +``Model(...).model_dump()``, or ``Model.model_validate(...).model_dump()`` return. Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing response schemas, and returns with non-literal status codes are classified as @@ -28,6 +28,7 @@ from typing import Any, Literal HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"} NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value} DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web") +IGNORE_COMMENT_MARKERS = ("response-contract:ignore",) type Classification = Literal["valid", "mismatch", "unknown", "refactorable"] type ActualKind = Literal[ @@ -41,6 +42,7 @@ type ActualKind = Literal[ "unknown", ] type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef +type ModelValueSource = Literal["constructor", "model_validate"] HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus} HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus}) @@ -109,18 +111,22 @@ class VariableAssignmentSummary: """Track whether a local name is safe to treat as one specific response model.""" known_models: set[str] = field(default_factory=set) + known_sources: set[ModelValueSource] = field(default_factory=set) has_unknown_assignment: bool = False - def add_known(self, model: str) -> None: + def add_known(self, model: str, source: ModelValueSource) -> None: self.known_models.add(model) + self.known_sources.add(source) def add_unknown(self) -> None: self.has_unknown_assignment = True - def single_known_model(self) -> str | None: + def single_known_model(self) -> tuple[str, ModelValueSource] | None: if self.has_unknown_assignment or len(self.known_models) != 1: return None - return next(iter(self.known_models)) + model = next(iter(self.known_models)) + source: ModelValueSource = "constructor" if self.known_sources == {"constructor"} else "model_validate" + return model, source def dotted_name(node: ast.AST) -> str | None: @@ -249,6 +255,12 @@ def model_name_from_model_validate_call(node: ast.AST) -> str | None: return None +def model_value_from_model_validate_call(node: ast.AST) -> tuple[str, ModelValueSource] | None: + if model_name := model_name_from_model_validate_call(node): + return model_name, "model_validate" + return None + + def model_name_from_constructor_call(node: ast.AST) -> str | None: if not isinstance(node, ast.Call): return None @@ -257,6 +269,12 @@ def model_name_from_constructor_call(node: ast.AST) -> str | None: return None +def model_value_from_constructor_call(node: ast.AST) -> tuple[str, ModelValueSource] | None: + if model_name := model_name_from_constructor_call(node): + return model_name, "constructor" + return None + + def model_name_from_model_dump(node: ast.AST) -> str | None: if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump": return None @@ -272,6 +290,10 @@ def model_name_from_model_value(node: ast.AST) -> str | None: return model_name_from_model_validate_call(node) or model_name_from_constructor_call(node) +def model_value_from_model_value(node: ast.AST) -> tuple[str, ModelValueSource] | None: + return model_value_from_model_validate_call(node) or model_value_from_constructor_call(node) + + def model_name_from_dump_response(node: ast.AST) -> str | None: if not isinstance(node, ast.Call): return None @@ -287,7 +309,7 @@ def model_name_from_dump_response(node: ast.AST) -> str | None: def actual_kind_from_expr( - expr: ast.AST | None, variable_models: dict[str, str] | None = None + expr: ast.AST | None, variable_models: dict[str, tuple[str, ModelValueSource]] | None = None ) -> tuple[ActualKind, str | None]: if expr is None: return "none", None @@ -299,10 +321,14 @@ def actual_kind_from_expr( if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump": dumped_value = expr.func.value if isinstance(dumped_value, ast.Name) and variable_models: - # A variable dump can match today, but it bypasses dump_response and - # is easier to drift; keep it visible as refactorable. - model_name = variable_models.get(dumped_value.id) - if model_name: + model_assignment = variable_models.get(dumped_value.id) + if model_assignment: + model_name, source = model_assignment + if source == "constructor": + return "model", model_name + # A variable dump from model_validate can match today, but it + # bypasses dump_response and is easier to drift; keep it visible + # as refactorable. return "model_dump_variable", model_name model_dump_model = model_name_from_model_dump(expr) @@ -325,7 +351,9 @@ def actual_kind_from_expr( return "unknown", None -def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse: +def actual_response_from_return( + return_node: ast.Return, variable_models: dict[str, tuple[str, ModelValueSource]] +) -> ActualResponse: status: int | None = 200 body_expr = return_node.value @@ -363,18 +391,21 @@ def target_names(target: ast.AST) -> Iterable[str]: def record_assignment( - assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None + assignments: defaultdict[str, VariableAssignmentSummary], + targets: Iterable[str], + model_assignment: tuple[str, ModelValueSource] | None, ) -> None: for target in targets: - if model_name is None: + if model_assignment is None: # Once a name receives an unknown value, later model_dump() calls on it # are no longer a reliable signal for the returned schema. assignments[target].add_unknown() else: - assignments[target].add_known(model_name) + model_name, source = model_assignment + assignments[target].add_known(model_name, source) -def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]: +def variable_model_assignments_for_method(method: MethodNode) -> dict[str, tuple[str, ModelValueSource]]: """Infer local variables that are unambiguously assigned one response model.""" assignments: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary) @@ -385,10 +416,10 @@ def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]: record_assignment( assignments, (name for target in targets for name in target_names(target)), - model_name_from_model_value(value), + model_value_from_model_value(value), ) case ast.AnnAssign(target=target, value=value) if value is not None: - record_assignment(assignments, target_names(target), model_name_from_model_value(value)) + record_assignment(assignments, target_names(target), model_value_from_model_value(value)) case ast.AugAssign(target=target) | ast.For(target=target) | ast.AsyncFor(target=target): # Mutation and loop targets overwrite prior values with runtime-dependent data. record_assignment(assignments, target_names(target), None) @@ -399,9 +430,13 @@ def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]: case ast.ExceptHandler(name=name) if name: assignments[name].add_unknown() case ast.NamedExpr(target=target, value=value): - record_assignment(assignments, target_names(target), model_name_from_model_value(value)) + record_assignment(assignments, target_names(target), model_value_from_model_value(value)) - return {name: model for name, summary in assignments.items() if (model := summary.single_known_model()) is not None} + return { + name: assignment + for name, summary in assignments.items() + if (assignment := summary.single_known_model()) is not None + } def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]: @@ -545,13 +580,52 @@ def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]: yield from sorted(child for child in path.rglob("*.py") if child.is_file()) +def node_start_lineno(node: ast.ClassDef | MethodNode) -> int: + decorator_lines = [decorator.lineno for decorator in node.decorator_list] + if decorator_lines: + return min(decorator_lines) + return node.lineno + + +def line_has_ignore_marker(line: str) -> bool: + _, marker, comment = line.partition("#") + if not marker: + return False + normalized = comment.lower() + return any(ignore_marker in normalized for ignore_marker in IGNORE_COMMENT_MARKERS) + + +def node_has_ignore_comment(lines: Sequence[str], node: ast.ClassDef | MethodNode) -> bool: + start = node_start_lineno(node) + end = node.end_lineno or node.lineno + if any(line_has_ignore_marker(line) for line in lines[start - 1 : end]): + return True + + line_index = start - 2 + while line_index >= 0: + stripped = lines[line_index].strip() + if not stripped: + line_index -= 1 + continue + if not stripped.startswith("#"): + break + if line_has_ignore_marker(lines[line_index]): + return True + line_index -= 1 + return False + + def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]: - module = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path)) + source = file_path.read_text(encoding="utf-8") + lines = source.splitlines() + module = ast.parse(source, filename=str(file_path)) checks: list[ContractCheck] = [] for node in module.body: if not isinstance(node, ast.ClassDef): continue + if node_has_ignore_comment(lines, node): + continue class_routes = routes_from_decorators(node.decorator_list) class_documented = response_docs_from_decorators(node.decorator_list) @@ -559,6 +633,8 @@ def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]: for item in node.body: if not isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) or item.name not in HTTP_METHODS: continue + if node_has_ignore_comment(lines, item): + continue routes = routes_from_decorators(item.decorator_list) or class_routes if not routes: diff --git a/api/fields/base.py b/api/fields/base.py index b806ab6c9c8..826a07d00ea 100644 --- a/api/fields/base.py +++ b/api/fields/base.py @@ -7,7 +7,8 @@ class ResponseModel(BaseModel): model_config = ConfigDict( from_attributes=True, extra="ignore", - populate_by_name=True, + validate_by_name=True, + validate_by_alias=True, serialize_by_alias=True, protected_namespaces=(), ) diff --git a/api/openapi/markdown/openapi-openapi.md b/api/openapi/markdown/openapi-openapi.md index 4bb6761c22e..08544f20a9d 100644 --- a/api/openapi/markdown/openapi-openapi.md +++ b/api/openapi/markdown/openapi-openapi.md @@ -990,6 +990,12 @@ Pagination for GET /account/sessions. Strict (extra='forbid'). | last_used_at | string | | No | | prefix | string | | Yes | +#### SimpleResultResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| result | string | | Yes | + #### SupportedAppType App types the ``app`` usage face (``get app``) lists and filters. diff --git a/api/tests/unit_tests/commands/test_lint_response_contracts.py b/api/tests/unit_tests/commands/test_lint_response_contracts.py index 351fdf0d923..68c4cbf966e 100644 --- a/api/tests/unit_tests/commands/test_lint_response_contracts.py +++ b/api/tests/unit_tests/commands/test_lint_response_contracts.py @@ -77,6 +77,25 @@ class AnnotationApi(Resource): assert "prefer dump_response" in checks[0].reason +def test_constructor_variable_model_dump_is_valid(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/annotations") +class AnnotationApi(Resource): + @ns.response(201, "Created", ns.models[AnnotationResponse.__name__]) + def post(self): + response = AnnotationResponse(id="new", name=name) + return response.model_dump(mode="json"), 201 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "valid" + assert checks[0].actual[0].kind == "model" + assert checks[0].actual[0].model == "AnnotationResponse" + + def test_variable_model_dump_with_wrong_documented_schema_is_mismatch(tmp_path: Path): checks = _checks_for_source( tmp_path, @@ -117,6 +136,38 @@ class StreamApi(Resource): assert {actual.model for actual in checks[0].actual} == {"StreamResponse"} +def test_response_contract_ignore_comment_skips_route_method(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/binary") +class BinaryApi(Resource): + # response-contract:ignore binary response + @ns.response(200, "Binary file") + def get(self): + return send_file(path) + + +# response-contract:ignore compact Flask response +@ns.route("/compact") +class CompactApi(Resource): + def get(self): + return make_response({"url": "https://example.com"}) + + +@ns.route("/regular") +class RegularApi(Resource): + @ns.response(200, "OK", ns.models[RegularResponse.__name__]) + def get(self): + return dump_response(RegularResponse, {}) +""", + ) + + assert len(checks) == 1 + assert checks[0].class_name == "RegularApi" + assert checks[0].classification == "valid" + + def test_main_is_report_only_by_default_for_mismatches(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): module = _load_lint_response_contracts_module() controller_path = tmp_path / "controllers" / "sample.py" diff --git a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py index ddd72f604d6..aabed01262c 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py +++ b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py @@ -3,13 +3,36 @@ from __future__ import annotations import sys -from types import SimpleNamespace +import uuid from unittest.mock import Mock import pytest from flask import Flask from controllers.openapi._models import AppRunRequest +from models import Account +from models.model import App, AppMode + +_TEST_APP_ID = str(uuid.uuid4()) +_TEST_TENANT_ID = str(uuid.uuid4()) +_TEST_ACCOUNT_ID = str(uuid.uuid4()) + + +def _make_app() -> App: + app = App() + app.id = _TEST_APP_ID + app.tenant_id = _TEST_TENANT_ID + app.name = "Streaming app" + app.mode = AppMode.CHAT + app.enable_site = False + app.enable_api = True + return app + + +def _make_account() -> Account: + account = Account(name="OpenAPI caller", email="caller@example.com") + account.id = _TEST_ACCOUNT_ID + return account def test_app_run_request_has_no_response_mode_field(): @@ -40,15 +63,19 @@ def test_run_chat_always_calls_generate_with_streaming_true( from controllers.openapi.app_run import _run_chat generate_mock = Mock(return_value=iter([])) + + class GenerateService: + generate = generate_mock + monkeypatch.setattr( sys.modules["controllers.openapi.app_run"], "AppGenerateService", - SimpleNamespace(generate=generate_mock), + GenerateService, ) - with app.test_request_context("/openapi/v1/apps/app-1/run", method="POST"): + with app.test_request_context(f"/openapi/v1/apps/{_TEST_APP_ID}/run", method="POST"): _run_chat( - SimpleNamespace(id="app-1", tenant_id="t-1"), - SimpleNamespace(id="acct-1"), + _make_app(), + _make_account(), AppRunRequest(inputs={}, query="hello"), ) _, kwargs = generate_mock.call_args @@ -80,11 +107,11 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app: Flask, bypass_pipel auth_data = AuthData.model_construct( token_type=TokenType.OAUTH_ACCOUNT, - account_id=uuid.uuid4(), + account_id=uuid.UUID(_TEST_ACCOUNT_ID), token_hash="test", scopes=frozenset({Scope.FULL}), - app=SimpleNamespace(id="app-1", tenant_id="t-1"), - caller=SimpleNamespace(id="acct-1"), + app=_make_app(), + caller=_make_account(), caller_kind="account", ) diff --git a/api/tests/unit_tests/controllers/openapi/test_contract.py b/api/tests/unit_tests/controllers/openapi/test_contract.py index b8773f56df8..69263a5d929 100644 --- a/api/tests/unit_tests/controllers/openapi/test_contract.py +++ b/api/tests/unit_tests/controllers/openapi/test_contract.py @@ -5,6 +5,7 @@ view function decorated with @accepts/@returns, driven inside a request context. """ from functools import wraps +from typing import Any, cast import pytest from pydantic import BaseModel, ConfigDict, Field @@ -100,7 +101,7 @@ def test_accepts_validation_error_is_sanitized_and_structured(app): with pytest.raises(UnprocessableEntity) as exc_info: view() - data = exc_info.value.data + data = cast(dict[str, Any], cast(Any, exc_info.value).data) assert data["message"] == "Request validation failed" assert isinstance(data["errors"], list) assert data["errors"] diff --git a/packages/contracts/generated/api/openapi/types.gen.ts b/packages/contracts/generated/api/openapi/types.gen.ts index 2d47f947247..622fb9c63f8 100644 --- a/packages/contracts/generated/api/openapi/types.gen.ts +++ b/packages/contracts/generated/api/openapi/types.gen.ts @@ -405,6 +405,10 @@ export type SessionRow = { prefix: string } +export type SimpleResultResponse = { + result: string +} + export type SupportedAppType = 'advanced-chat' | 'agent-chat' | 'chat' | 'completion' | 'workflow' export type TaskStopResponse = { diff --git a/packages/contracts/generated/api/openapi/zod.gen.ts b/packages/contracts/generated/api/openapi/zod.gen.ts index 557447cc769..446f47817ef 100644 --- a/packages/contracts/generated/api/openapi/zod.gen.ts +++ b/packages/contracts/generated/api/openapi/zod.gen.ts @@ -501,6 +501,13 @@ export const zSessionListResponse = z.object({ total: z.int(), }) +/** + * SimpleResultResponse + */ +export const zSimpleResultResponse = z.object({ + result: z.string(), +}) + /** * SupportedAppType * diff --git a/packages/contracts/openapi-ts.api.config.ts b/packages/contracts/openapi-ts.api.config.ts index 8fce8a25bd3..1adbf4fda8e 100644 --- a/packages/contracts/openapi-ts.api.config.ts +++ b/packages/contracts/openapi-ts.api.config.ts @@ -10,13 +10,21 @@ type SwaggerSchema = JsonObject & { $ref?: string } +type OpenApiMediaType = JsonObject & { + schema?: SwaggerSchema +} + +type OpenApiResponse = JsonObject & { + content?: Record +} + type OpenApiComponents = JsonObject & { schemas?: Record } type SwaggerOperation = JsonObject & { operationId?: string - responses?: Record + responses?: Record } type SwaggerDocument = JsonObject & { @@ -52,6 +60,17 @@ const currentDir = path.dirname(fileURLToPath(import.meta.url)) const apiOpenApiDir = path.resolve(currentDir, 'openapi') const operationMethods = new Set(['delete', 'get', 'patch', 'post', 'put']) +const pydanticDecimalStringPattern = '^(?!^[-+.]*$)[+-]?0*\\d*\\.?\\d*$' +const codegenSafeDecimalStringPattern = '^(?![-+.]*$)[+-]?0*\\d*\\.?\\d*$' + +const opaqueJsonContent = (): Record => ({ + 'application/json': { + schema: { + additionalProperties: true, + type: 'object', + }, + }, +}) const apiSpecs: ApiSpec[] = [ { filename: 'console-openapi.json', name: 'console' }, @@ -182,6 +201,46 @@ const addOperationIds = (document: SwaggerDocument) => { } } +const isOpaqueContractResponse = (response: OpenApiResponse) => { + const content = response.content + if (!isObject(content)) + return false + + return Object.entries(content).some(([mediaType, media]) => { + if (!isObject(media)) + return false + + return (mediaType === 'application/json' || mediaType === 'text/event-stream') && !('schema' in media) + }) +} + +const hasOpaqueContractSuccessResponse = (operation: SwaggerOperation) => { + return Object.entries(operation.responses ?? {}).some(([status, response]) => { + return /^2\d\d$/.test(status) && isObject(response) && isOpaqueContractResponse(response) + }) +} + +const normalizeOpaqueContractResponses = (document: SwaggerDocument) => { + // Some backend endpoints has no schema (e.g. external) and will trap heyapi here + // So we forge an opaque schema here + for (const pathItem of Object.values(document.paths ?? {})) { + for (const [method, operation] of Object.entries(pathItem)) { + if (!operationMethods.has(method) || !isObject(operation)) + continue + + const swaggerOperation = operation as SwaggerOperation + if (!hasOpaqueContractSuccessResponse(swaggerOperation)) + continue + + Object.values(swaggerOperation.responses ?? {}) + .filter(response => isObject(response) && isOpaqueContractResponse(response)) + .forEach((response) => { + response.content = opaqueJsonContent() + }) + } + } +} + const hasSuccessResponse = (operation: SwaggerOperation) => { return Object.entries(operation.responses ?? {}).some(([status, response]) => { if (!/^2\d\d$/.test(status)) @@ -215,6 +274,7 @@ const filterContractOperations = (document: SwaggerDocument) => { } const normalizeApiSwagger = (document: SwaggerDocument) => { + normalizeOpaqueContractResponses(document) filterContractOperations(document) addOperationIds(document) @@ -380,10 +440,20 @@ const createApiConfig = (job: ApiJob): UserConfig => ({ 'name': 'zod', '~resolvers': { string: (ctx) => { - if (ctx.schema.format !== 'binary') - return undefined + if (ctx.schema.format === 'binary') + return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File'))) - return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File'))) + if (ctx.schema.pattern === pydanticDecimalStringPattern) { + // the pydantic generated regex will emit error like + // regexp/no-useless-assertions, so patch the regex here + return $(ctx.symbols.z) + .attr('string') + .call() + .attr('regex') + .call($.regexp(codegenSafeDecimalStringPattern)) + } + + return undefined }, }, },