mirror of
https://github.com/langgenius/dify.git
synced 2026-06-26 06:41:10 +08:00
refactor(api): migrate response contract tooling to BaseModel
This commit is contained in:
parent
bb921bcc45
commit
c7051f7af8
@ -20,7 +20,7 @@ openapi_ns = Namespace("openapi", description="User-scoped operations", path="/"
|
||||
|
||||
# Register response/query models BEFORE importing controller modules so that
|
||||
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
|
||||
from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.fields import EventStreamResponse, SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
@ -95,6 +95,7 @@ register_response_schema_models(
|
||||
openapi_ns,
|
||||
ErrorBody,
|
||||
EventStreamResponse,
|
||||
SimpleResultResponse,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
@ -61,7 +61,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
def _translate_service_errors() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
@ -166,6 +166,7 @@ class AppRunApi(Resource):
|
||||
surface="apps",
|
||||
)
|
||||
|
||||
# response-contract:ignore compact_generate_response
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
|
||||
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
|
||||
This checker intentionally stays conservative. It only reports a hard schema
|
||||
mismatch when both sides are statically known for the same 2xx status code:
|
||||
a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)``
|
||||
or ``Model.model_validate(...).model_dump()`` return.
|
||||
a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)``,
|
||||
``Model(...).model_dump()``, or ``Model.model_validate(...).model_dump()`` return.
|
||||
|
||||
Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing
|
||||
response schemas, and returns with non-literal status codes are classified as
|
||||
@ -28,6 +28,7 @@ from typing import Any, Literal
|
||||
HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"}
|
||||
NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value}
|
||||
DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web")
|
||||
IGNORE_COMMENT_MARKERS = ("response-contract:ignore",)
|
||||
|
||||
type Classification = Literal["valid", "mismatch", "unknown", "refactorable"]
|
||||
type ActualKind = Literal[
|
||||
@ -41,6 +42,7 @@ type ActualKind = Literal[
|
||||
"unknown",
|
||||
]
|
||||
type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef
|
||||
type ModelValueSource = Literal["constructor", "model_validate"]
|
||||
|
||||
HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus}
|
||||
HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus})
|
||||
@ -109,18 +111,22 @@ class VariableAssignmentSummary:
|
||||
"""Track whether a local name is safe to treat as one specific response model."""
|
||||
|
||||
known_models: set[str] = field(default_factory=set)
|
||||
known_sources: set[ModelValueSource] = field(default_factory=set)
|
||||
has_unknown_assignment: bool = False
|
||||
|
||||
def add_known(self, model: str) -> None:
|
||||
def add_known(self, model: str, source: ModelValueSource) -> None:
|
||||
self.known_models.add(model)
|
||||
self.known_sources.add(source)
|
||||
|
||||
def add_unknown(self) -> None:
|
||||
self.has_unknown_assignment = True
|
||||
|
||||
def single_known_model(self) -> str | None:
|
||||
def single_known_model(self) -> tuple[str, ModelValueSource] | None:
|
||||
if self.has_unknown_assignment or len(self.known_models) != 1:
|
||||
return None
|
||||
return next(iter(self.known_models))
|
||||
model = next(iter(self.known_models))
|
||||
source: ModelValueSource = "constructor" if self.known_sources == {"constructor"} else "model_validate"
|
||||
return model, source
|
||||
|
||||
|
||||
def dotted_name(node: ast.AST) -> str | None:
|
||||
@ -249,6 +255,12 @@ def model_name_from_model_validate_call(node: ast.AST) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def model_value_from_model_validate_call(node: ast.AST) -> tuple[str, ModelValueSource] | None:
|
||||
if model_name := model_name_from_model_validate_call(node):
|
||||
return model_name, "model_validate"
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_constructor_call(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call):
|
||||
return None
|
||||
@ -257,6 +269,12 @@ def model_name_from_constructor_call(node: ast.AST) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def model_value_from_constructor_call(node: ast.AST) -> tuple[str, ModelValueSource] | None:
|
||||
if model_name := model_name_from_constructor_call(node):
|
||||
return model_name, "constructor"
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_model_dump(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump":
|
||||
return None
|
||||
@ -272,6 +290,10 @@ def model_name_from_model_value(node: ast.AST) -> str | None:
|
||||
return model_name_from_model_validate_call(node) or model_name_from_constructor_call(node)
|
||||
|
||||
|
||||
def model_value_from_model_value(node: ast.AST) -> tuple[str, ModelValueSource] | None:
|
||||
return model_value_from_model_validate_call(node) or model_value_from_constructor_call(node)
|
||||
|
||||
|
||||
def model_name_from_dump_response(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call):
|
||||
return None
|
||||
@ -287,7 +309,7 @@ def model_name_from_dump_response(node: ast.AST) -> str | None:
|
||||
|
||||
|
||||
def actual_kind_from_expr(
|
||||
expr: ast.AST | None, variable_models: dict[str, str] | None = None
|
||||
expr: ast.AST | None, variable_models: dict[str, tuple[str, ModelValueSource]] | None = None
|
||||
) -> tuple[ActualKind, str | None]:
|
||||
if expr is None:
|
||||
return "none", None
|
||||
@ -299,10 +321,14 @@ def actual_kind_from_expr(
|
||||
if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump":
|
||||
dumped_value = expr.func.value
|
||||
if isinstance(dumped_value, ast.Name) and variable_models:
|
||||
# A variable dump can match today, but it bypasses dump_response and
|
||||
# is easier to drift; keep it visible as refactorable.
|
||||
model_name = variable_models.get(dumped_value.id)
|
||||
if model_name:
|
||||
model_assignment = variable_models.get(dumped_value.id)
|
||||
if model_assignment:
|
||||
model_name, source = model_assignment
|
||||
if source == "constructor":
|
||||
return "model", model_name
|
||||
# A variable dump from model_validate can match today, but it
|
||||
# bypasses dump_response and is easier to drift; keep it visible
|
||||
# as refactorable.
|
||||
return "model_dump_variable", model_name
|
||||
|
||||
model_dump_model = model_name_from_model_dump(expr)
|
||||
@ -325,7 +351,9 @@ def actual_kind_from_expr(
|
||||
return "unknown", None
|
||||
|
||||
|
||||
def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse:
|
||||
def actual_response_from_return(
|
||||
return_node: ast.Return, variable_models: dict[str, tuple[str, ModelValueSource]]
|
||||
) -> ActualResponse:
|
||||
status: int | None = 200
|
||||
body_expr = return_node.value
|
||||
|
||||
@ -363,18 +391,21 @@ def target_names(target: ast.AST) -> Iterable[str]:
|
||||
|
||||
|
||||
def record_assignment(
|
||||
assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None
|
||||
assignments: defaultdict[str, VariableAssignmentSummary],
|
||||
targets: Iterable[str],
|
||||
model_assignment: tuple[str, ModelValueSource] | None,
|
||||
) -> None:
|
||||
for target in targets:
|
||||
if model_name is None:
|
||||
if model_assignment is None:
|
||||
# Once a name receives an unknown value, later model_dump() calls on it
|
||||
# are no longer a reliable signal for the returned schema.
|
||||
assignments[target].add_unknown()
|
||||
else:
|
||||
assignments[target].add_known(model_name)
|
||||
model_name, source = model_assignment
|
||||
assignments[target].add_known(model_name, source)
|
||||
|
||||
|
||||
def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]:
|
||||
def variable_model_assignments_for_method(method: MethodNode) -> dict[str, tuple[str, ModelValueSource]]:
|
||||
"""Infer local variables that are unambiguously assigned one response model."""
|
||||
|
||||
assignments: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary)
|
||||
@ -385,10 +416,10 @@ def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]:
|
||||
record_assignment(
|
||||
assignments,
|
||||
(name for target in targets for name in target_names(target)),
|
||||
model_name_from_model_value(value),
|
||||
model_value_from_model_value(value),
|
||||
)
|
||||
case ast.AnnAssign(target=target, value=value) if value is not None:
|
||||
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
|
||||
record_assignment(assignments, target_names(target), model_value_from_model_value(value))
|
||||
case ast.AugAssign(target=target) | ast.For(target=target) | ast.AsyncFor(target=target):
|
||||
# Mutation and loop targets overwrite prior values with runtime-dependent data.
|
||||
record_assignment(assignments, target_names(target), None)
|
||||
@ -399,9 +430,13 @@ def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]:
|
||||
case ast.ExceptHandler(name=name) if name:
|
||||
assignments[name].add_unknown()
|
||||
case ast.NamedExpr(target=target, value=value):
|
||||
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
|
||||
record_assignment(assignments, target_names(target), model_value_from_model_value(value))
|
||||
|
||||
return {name: model for name, summary in assignments.items() if (model := summary.single_known_model()) is not None}
|
||||
return {
|
||||
name: assignment
|
||||
for name, summary in assignments.items()
|
||||
if (assignment := summary.single_known_model()) is not None
|
||||
}
|
||||
|
||||
|
||||
def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]:
|
||||
@ -545,13 +580,52 @@ def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]:
|
||||
yield from sorted(child for child in path.rglob("*.py") if child.is_file())
|
||||
|
||||
|
||||
def node_start_lineno(node: ast.ClassDef | MethodNode) -> int:
|
||||
decorator_lines = [decorator.lineno for decorator in node.decorator_list]
|
||||
if decorator_lines:
|
||||
return min(decorator_lines)
|
||||
return node.lineno
|
||||
|
||||
|
||||
def line_has_ignore_marker(line: str) -> bool:
|
||||
_, marker, comment = line.partition("#")
|
||||
if not marker:
|
||||
return False
|
||||
normalized = comment.lower()
|
||||
return any(ignore_marker in normalized for ignore_marker in IGNORE_COMMENT_MARKERS)
|
||||
|
||||
|
||||
def node_has_ignore_comment(lines: Sequence[str], node: ast.ClassDef | MethodNode) -> bool:
|
||||
start = node_start_lineno(node)
|
||||
end = node.end_lineno or node.lineno
|
||||
if any(line_has_ignore_marker(line) for line in lines[start - 1 : end]):
|
||||
return True
|
||||
|
||||
line_index = start - 2
|
||||
while line_index >= 0:
|
||||
stripped = lines[line_index].strip()
|
||||
if not stripped:
|
||||
line_index -= 1
|
||||
continue
|
||||
if not stripped.startswith("#"):
|
||||
break
|
||||
if line_has_ignore_marker(lines[line_index]):
|
||||
return True
|
||||
line_index -= 1
|
||||
return False
|
||||
|
||||
|
||||
def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]:
|
||||
module = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
lines = source.splitlines()
|
||||
module = ast.parse(source, filename=str(file_path))
|
||||
checks: list[ContractCheck] = []
|
||||
|
||||
for node in module.body:
|
||||
if not isinstance(node, ast.ClassDef):
|
||||
continue
|
||||
if node_has_ignore_comment(lines, node):
|
||||
continue
|
||||
|
||||
class_routes = routes_from_decorators(node.decorator_list)
|
||||
class_documented = response_docs_from_decorators(node.decorator_list)
|
||||
@ -559,6 +633,8 @@ def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]:
|
||||
for item in node.body:
|
||||
if not isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) or item.name not in HTTP_METHODS:
|
||||
continue
|
||||
if node_has_ignore_comment(lines, item):
|
||||
continue
|
||||
|
||||
routes = routes_from_decorators(item.decorator_list) or class_routes
|
||||
if not routes:
|
||||
|
||||
@ -7,7 +7,8 @@ class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
validate_by_name=True,
|
||||
validate_by_alias=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
@ -990,6 +990,12 @@ Pagination for GET /account/sessions. Strict (extra='forbid').
|
||||
| last_used_at | string | | No |
|
||||
| prefix | string | | Yes |
|
||||
|
||||
#### SimpleResultResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| result | string | | Yes |
|
||||
|
||||
#### SupportedAppType
|
||||
|
||||
App types the ``app`` usage face (``get app``) lists and filters.
|
||||
|
||||
@ -77,6 +77,25 @@ class AnnotationApi(Resource):
|
||||
assert "prefer dump_response" in checks[0].reason
|
||||
|
||||
|
||||
def test_constructor_variable_model_dump_is_valid(tmp_path: Path):
|
||||
checks = _checks_for_source(
|
||||
tmp_path,
|
||||
"""
|
||||
@ns.route("/annotations")
|
||||
class AnnotationApi(Resource):
|
||||
@ns.response(201, "Created", ns.models[AnnotationResponse.__name__])
|
||||
def post(self):
|
||||
response = AnnotationResponse(id="new", name=name)
|
||||
return response.model_dump(mode="json"), 201
|
||||
""",
|
||||
)
|
||||
|
||||
assert len(checks) == 1
|
||||
assert checks[0].classification == "valid"
|
||||
assert checks[0].actual[0].kind == "model"
|
||||
assert checks[0].actual[0].model == "AnnotationResponse"
|
||||
|
||||
|
||||
def test_variable_model_dump_with_wrong_documented_schema_is_mismatch(tmp_path: Path):
|
||||
checks = _checks_for_source(
|
||||
tmp_path,
|
||||
@ -117,6 +136,38 @@ class StreamApi(Resource):
|
||||
assert {actual.model for actual in checks[0].actual} == {"StreamResponse"}
|
||||
|
||||
|
||||
def test_response_contract_ignore_comment_skips_route_method(tmp_path: Path):
|
||||
checks = _checks_for_source(
|
||||
tmp_path,
|
||||
"""
|
||||
@ns.route("/binary")
|
||||
class BinaryApi(Resource):
|
||||
# response-contract:ignore binary response
|
||||
@ns.response(200, "Binary file")
|
||||
def get(self):
|
||||
return send_file(path)
|
||||
|
||||
|
||||
# response-contract:ignore compact Flask response
|
||||
@ns.route("/compact")
|
||||
class CompactApi(Resource):
|
||||
def get(self):
|
||||
return make_response({"url": "https://example.com"})
|
||||
|
||||
|
||||
@ns.route("/regular")
|
||||
class RegularApi(Resource):
|
||||
@ns.response(200, "OK", ns.models[RegularResponse.__name__])
|
||||
def get(self):
|
||||
return dump_response(RegularResponse, {})
|
||||
""",
|
||||
)
|
||||
|
||||
assert len(checks) == 1
|
||||
assert checks[0].class_name == "RegularApi"
|
||||
assert checks[0].classification == "valid"
|
||||
|
||||
|
||||
def test_main_is_report_only_by_default_for_mismatches(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
module = _load_lint_response_contracts_module()
|
||||
controller_path = tmp_path / "controllers" / "sample.py"
|
||||
|
||||
@ -3,13 +3,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
|
||||
_TEST_APP_ID = str(uuid.uuid4())
|
||||
_TEST_TENANT_ID = str(uuid.uuid4())
|
||||
_TEST_ACCOUNT_ID = str(uuid.uuid4())
|
||||
|
||||
|
||||
def _make_app() -> App:
|
||||
app = App()
|
||||
app.id = _TEST_APP_ID
|
||||
app.tenant_id = _TEST_TENANT_ID
|
||||
app.name = "Streaming app"
|
||||
app.mode = AppMode.CHAT
|
||||
app.enable_site = False
|
||||
app.enable_api = True
|
||||
return app
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(name="OpenAPI caller", email="caller@example.com")
|
||||
account.id = _TEST_ACCOUNT_ID
|
||||
return account
|
||||
|
||||
|
||||
def test_app_run_request_has_no_response_mode_field():
|
||||
@ -40,15 +63,19 @@ def test_run_chat_always_calls_generate_with_streaming_true(
|
||||
from controllers.openapi.app_run import _run_chat
|
||||
|
||||
generate_mock = Mock(return_value=iter([]))
|
||||
|
||||
class GenerateService:
|
||||
generate = generate_mock
|
||||
|
||||
monkeypatch.setattr(
|
||||
sys.modules["controllers.openapi.app_run"],
|
||||
"AppGenerateService",
|
||||
SimpleNamespace(generate=generate_mock),
|
||||
GenerateService,
|
||||
)
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/run", method="POST"):
|
||||
with app.test_request_context(f"/openapi/v1/apps/{_TEST_APP_ID}/run", method="POST"):
|
||||
_run_chat(
|
||||
SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
SimpleNamespace(id="acct-1"),
|
||||
_make_app(),
|
||||
_make_account(),
|
||||
AppRunRequest(inputs={}, query="hello"),
|
||||
)
|
||||
_, kwargs = generate_mock.call_args
|
||||
@ -80,11 +107,11 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app: Flask, bypass_pipel
|
||||
|
||||
auth_data = AuthData.model_construct(
|
||||
token_type=TokenType.OAUTH_ACCOUNT,
|
||||
account_id=uuid.uuid4(),
|
||||
account_id=uuid.UUID(_TEST_ACCOUNT_ID),
|
||||
token_hash="test",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
app=SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
app=_make_app(),
|
||||
caller=_make_account(),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ view function decorated with @accepts/@returns, driven inside a request context.
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@ -100,7 +101,7 @@ def test_accepts_validation_error_is_sanitized_and_structured(app):
|
||||
with pytest.raises(UnprocessableEntity) as exc_info:
|
||||
view()
|
||||
|
||||
data = exc_info.value.data
|
||||
data = cast(dict[str, Any], cast(Any, exc_info.value).data)
|
||||
assert data["message"] == "Request validation failed"
|
||||
assert isinstance(data["errors"], list)
|
||||
assert data["errors"]
|
||||
|
||||
@ -405,6 +405,10 @@ export type SessionRow = {
|
||||
prefix: string
|
||||
}
|
||||
|
||||
export type SimpleResultResponse = {
|
||||
result: string
|
||||
}
|
||||
|
||||
export type SupportedAppType = 'advanced-chat' | 'agent-chat' | 'chat' | 'completion' | 'workflow'
|
||||
|
||||
export type TaskStopResponse = {
|
||||
|
||||
@ -501,6 +501,13 @@ export const zSessionListResponse = z.object({
|
||||
total: z.int(),
|
||||
})
|
||||
|
||||
/**
|
||||
* SimpleResultResponse
|
||||
*/
|
||||
export const zSimpleResultResponse = z.object({
|
||||
result: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* SupportedAppType
|
||||
*
|
||||
|
||||
@ -10,13 +10,21 @@ type SwaggerSchema = JsonObject & {
|
||||
$ref?: string
|
||||
}
|
||||
|
||||
type OpenApiMediaType = JsonObject & {
|
||||
schema?: SwaggerSchema
|
||||
}
|
||||
|
||||
type OpenApiResponse = JsonObject & {
|
||||
content?: Record<string, OpenApiMediaType>
|
||||
}
|
||||
|
||||
type OpenApiComponents = JsonObject & {
|
||||
schemas?: Record<string, SwaggerSchema>
|
||||
}
|
||||
|
||||
type SwaggerOperation = JsonObject & {
|
||||
operationId?: string
|
||||
responses?: Record<string, unknown>
|
||||
responses?: Record<string, OpenApiResponse>
|
||||
}
|
||||
|
||||
type SwaggerDocument = JsonObject & {
|
||||
@ -52,6 +60,17 @@ const currentDir = path.dirname(fileURLToPath(import.meta.url))
|
||||
const apiOpenApiDir = path.resolve(currentDir, 'openapi')
|
||||
|
||||
const operationMethods = new Set(['delete', 'get', 'patch', 'post', 'put'])
|
||||
const pydanticDecimalStringPattern = '^(?!^[-+.]*$)[+-]?0*\\d*\\.?\\d*$'
|
||||
const codegenSafeDecimalStringPattern = '^(?![-+.]*$)[+-]?0*\\d*\\.?\\d*$'
|
||||
|
||||
const opaqueJsonContent = (): Record<string, OpenApiMediaType> => ({
|
||||
'application/json': {
|
||||
schema: {
|
||||
additionalProperties: true,
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const apiSpecs: ApiSpec[] = [
|
||||
{ filename: 'console-openapi.json', name: 'console' },
|
||||
@ -182,6 +201,46 @@ const addOperationIds = (document: SwaggerDocument) => {
|
||||
}
|
||||
}
|
||||
|
||||
const isOpaqueContractResponse = (response: OpenApiResponse) => {
|
||||
const content = response.content
|
||||
if (!isObject(content))
|
||||
return false
|
||||
|
||||
return Object.entries(content).some(([mediaType, media]) => {
|
||||
if (!isObject(media))
|
||||
return false
|
||||
|
||||
return (mediaType === 'application/json' || mediaType === 'text/event-stream') && !('schema' in media)
|
||||
})
|
||||
}
|
||||
|
||||
const hasOpaqueContractSuccessResponse = (operation: SwaggerOperation) => {
|
||||
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
|
||||
return /^2\d\d$/.test(status) && isObject(response) && isOpaqueContractResponse(response)
|
||||
})
|
||||
}
|
||||
|
||||
const normalizeOpaqueContractResponses = (document: SwaggerDocument) => {
|
||||
// Some backend endpoints has no schema (e.g. external) and will trap heyapi here
|
||||
// So we forge an opaque schema here
|
||||
for (const pathItem of Object.values(document.paths ?? {})) {
|
||||
for (const [method, operation] of Object.entries(pathItem)) {
|
||||
if (!operationMethods.has(method) || !isObject(operation))
|
||||
continue
|
||||
|
||||
const swaggerOperation = operation as SwaggerOperation
|
||||
if (!hasOpaqueContractSuccessResponse(swaggerOperation))
|
||||
continue
|
||||
|
||||
Object.values(swaggerOperation.responses ?? {})
|
||||
.filter(response => isObject(response) && isOpaqueContractResponse(response))
|
||||
.forEach((response) => {
|
||||
response.content = opaqueJsonContent()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const hasSuccessResponse = (operation: SwaggerOperation) => {
|
||||
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
|
||||
if (!/^2\d\d$/.test(status))
|
||||
@ -215,6 +274,7 @@ const filterContractOperations = (document: SwaggerDocument) => {
|
||||
}
|
||||
|
||||
const normalizeApiSwagger = (document: SwaggerDocument) => {
|
||||
normalizeOpaqueContractResponses(document)
|
||||
filterContractOperations(document)
|
||||
addOperationIds(document)
|
||||
|
||||
@ -380,10 +440,20 @@ const createApiConfig = (job: ApiJob): UserConfig => ({
|
||||
'name': 'zod',
|
||||
'~resolvers': {
|
||||
string: (ctx) => {
|
||||
if (ctx.schema.format !== 'binary')
|
||||
return undefined
|
||||
if (ctx.schema.format === 'binary')
|
||||
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
|
||||
|
||||
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
|
||||
if (ctx.schema.pattern === pydanticDecimalStringPattern) {
|
||||
// the pydantic generated regex will emit error like
|
||||
// regexp/no-useless-assertions, so patch the regex here
|
||||
return $(ctx.symbols.z)
|
||||
.attr('string')
|
||||
.call()
|
||||
.attr('regex')
|
||||
.call($.regexp(codegenSafeDecimalStringPattern))
|
||||
}
|
||||
|
||||
return undefined
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user