From 8ec9622e19cf4e55076dc383145e47d68d54b5dc Mon Sep 17 00:00:00 2001 From: chariri Date: Fri, 26 Jun 2026 03:22:24 +0900 Subject: [PATCH] refactor(api): migrate web auth endpoints to BaseModel --- api/controllers/console/auth/login.py | 45 +++++++---- api/controllers/web/login.py | 34 ++++---- api/controllers/web/passport.py | 43 +++++----- api/openapi/markdown/console-openapi.md | 1 + api/openapi/markdown/web-openapi.md | 8 +- .../controllers/web/test_passport.py | 23 +++--- .../controllers/web/test_web_login.py | 4 +- .../controllers/web/test_web_passport.py | 5 +- .../api/console/refresh-token/types.gen.ts | 11 +++ .../api/console/refresh-token/zod.gen.ts | 8 ++ .../contracts/generated/api/web/types.gen.ts | 6 +- .../contracts/generated/api/web/zod.gen.ts | 9 ++- packages/contracts/openapi-ts.api.config.ts | 78 ++++++++++++++++++- 13 files changed, 193 insertions(+), 82 deletions(-) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 81f9ee4bae4..74d1ecc38f7 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -9,7 +9,12 @@ from werkzeug.exceptions import Unauthorized import services from configs import dify_config from constants.languages import get_valid_language -from controllers.common.fields import SimpleResultDataResponse, SimpleResultOptionalDataResponse, SimpleResultResponse +from controllers.common.fields import ( + SimpleResultDataResponse, + SimpleResultMessageResponse, + SimpleResultOptionalDataResponse, + SimpleResultResponse, +) from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( @@ -87,6 +92,7 @@ register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPay register_response_schema_models( console_ns, SimpleResultDataResponse, + SimpleResultMessageResponse, SimpleResultOptionalDataResponse, SimpleResultResponse, ) @@ -154,16 +160,19 @@ class LoginApi(Resource): if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available(): raise WorkspacesLimitExceeded() else: - return { - "result": "fail", - "data": "workspace not found, please contact system admin to invite you to join in a workspace", - } + return SimpleResultOptionalDataResponse( + result="fail", + data="workspace not found, please contact system admin to invite you to join in a workspace", + ).model_dump(mode="json") token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(normalized_email) # Create response with cookies instead of returning tokens in body - response = make_response({"result": "success"}) + # response-contract:ignore cookie-bearing Flask response + response = make_response( + SimpleResultOptionalDataResponse(result="success").model_dump(mode="json", exclude_none=True) + ) set_access_token_to_cookie(request, response, token_pair.access_token) set_refresh_token_to_cookie(request, response, token_pair.refresh_token) @@ -178,12 +187,11 @@ class LogoutApi(Resource): @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) @with_current_user def post(self, account: Account): - if isinstance(account, flask_login.AnonymousUserMixin): - response = make_response({"result": "success"}) - else: + # response-contract:ignore cookie-bearing Flask response + response = make_response(SimpleResultResponse(result="success").model_dump(mode="json")) + if not isinstance(account, flask_login.AnonymousUserMixin): AccountService.logout(account=account) flask_login.logout_user() - response = make_response({"result": "success"}) # Clear cookies on logout clear_access_token_from_cookie(response) @@ -219,7 +227,7 @@ class ResetPasswordSendEmailApi(Resource): is_allow_register=FeatureService.get_system_features().is_allow_register, ) - return {"result": "success", "data": token} + return SimpleResultDataResponse(result="success", data=token).model_dump(mode="json") @console_ns.route("/email-code-login") @@ -252,7 +260,7 @@ class EmailCodeLoginSendEmailApi(Resource): else: token = AccountService.send_email_code_login_email(account=account, language=language) - return {"result": "success", "data": token} + return SimpleResultDataResponse(result="success", data=token).model_dump(mode="json") @console_ns.route("/email-code-login/validity") @@ -326,7 +334,8 @@ class EmailCodeLoginApi(Resource): AccountService.reset_login_error_rate_limit(user_email) # Create response with cookies instead of returning tokens in body - response = make_response({"result": "success"}) + # response-contract:ignore cookie-bearing Flask response + response = make_response(SimpleResultResponse(result="success").model_dump(mode="json")) set_csrf_token_to_cookie(request, response, token_pair.csrf_token) # Set HTTP-only secure cookies for tokens @@ -338,18 +347,22 @@ class EmailCodeLoginApi(Resource): @console_ns.route("/refresh-token") class RefreshTokenApi(Resource): @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) + @console_ns.response(401, "Unauthorized", console_ns.models[SimpleResultMessageResponse.__name__]) def post(self): # Get refresh token from cookie instead of request body refresh_token = extract_refresh_token(request) if not refresh_token: - return {"result": "fail", "message": "No refresh token provided"}, 401 + return SimpleResultMessageResponse(result="fail", message="No refresh token provided").model_dump( + mode="json" + ), 401 try: new_token_pair = AccountService.refresh_token(refresh_token, session=db.session) # Create response with new cookies - response = make_response({"result": "success"}) + # response-contract:ignore cookie-bearing Flask response + response = make_response(SimpleResultResponse(result="success").model_dump(mode="json")) # Update cookies with new tokens set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token) @@ -357,7 +370,7 @@ class RefreshTokenApi(Resource): set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token) return response except Exception as e: - return {"result": "fail", "message": str(e)}, 401 + return SimpleResultMessageResponse(result="fail", message=str(e)).model_dump(mode="json"), 401 def _get_account_with_case_fallback(email: str): diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 2d8c38f5507..011bb43b880 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -9,6 +9,7 @@ from werkzeug.exceptions import Unauthorized import services from configs import dify_config from controllers.common.fields import ( + AccessTokenData, AccessTokenResultResponse, LoginStatusResponse, SimpleResultDataResponse, @@ -115,9 +116,10 @@ class LoginApi(Resource): raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) - return response + return AccessTokenResultResponse(result="success", data=AccessTokenData(access_token=token)).model_dump( + mode="json" + ) # this api helps frontend to check whether user is authenticated @@ -136,14 +138,12 @@ class LoginStatusApi(Resource): ) @web_ns.response(200, "Login status", web_ns.models[LoginStatusResponse.__name__]) def get(self): - app_code = request.args.get("app_code") - user_id = request.args.get("user_id") + query = LoginStatusQuery.model_validate(request.args.to_dict(flat=True)) + app_code = query.app_code + user_id = query.user_id token = extract_webapp_access_token(request) if not app_code: - return { - "logged_in": bool(token), - "app_logged_in": False, - } + return LoginStatusResponse(logged_in=bool(token), app_logged_in=False).model_dump(mode="json") app_id = AppService.get_app_id_by_code(app_code) is_public = not dify_config.ENTERPRISE_ENABLED or not WebAppAuthService.is_app_require_permission_check( app_id=app_id @@ -165,10 +165,7 @@ class LoginStatusApi(Resource): except Exception: app_logged_in = False - return { - "logged_in": user_logged_in, - "app_logged_in": app_logged_in, - } + return LoginStatusResponse(logged_in=user_logged_in, app_logged_in=app_logged_in).model_dump(mode="json") @web_ns.route("/logout") @@ -183,7 +180,8 @@ class LogoutApi(Resource): ) @web_ns.response(200, "Logout successful", web_ns.models[SimpleResultResponse.__name__]) def post(self): - response = make_response({"result": "success"}) + # response-contract:ignore hand-crafted response + response = make_response(SimpleResultResponse(result="success").model_dump(mode="json")) # enterprise SSO sets same site to None in https deployment # so we need to logout by calling api clear_webapp_access_token_from_cookie(response, samesite="None") @@ -216,9 +214,8 @@ class EmailCodeLoginSendEmailApi(Resource): account = WebAppAuthService.get_user_through_email(payload.email) if account is None: raise AuthenticationFailedError() - else: - token = WebAppAuthService.send_email_code_login_email(account=account, language=language) - return {"result": "success", "data": token} + token = WebAppAuthService.send_email_code_login_email(account=account, language=language) + return SimpleResultDataResponse(result="success", data=token).model_dump(mode="json") @web_ns.route("/email-code-login/validity") @@ -277,9 +274,10 @@ class EmailCodeLoginApi(Resource): token = WebAppAuthService.login(account=account) AccountService.reset_login_error_rate_limit(user_email) - response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) - return response + return AccessTokenResultResponse(result="success", data=AccessTokenData(access_token=token)).model_dump( + mode="json" + ) def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None: diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 99b75776280..c11ce824731 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -2,7 +2,7 @@ import uuid from datetime import UTC, datetime, timedelta from typing import Any -from flask import make_response, request +from flask import request from flask_restx import Resource from pydantic import BaseModel, Field from sqlalchemy import func, select @@ -10,11 +10,12 @@ from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config from constants import HEADER_NAME_APP_CODE -from controllers.common.fields import AccessTokenData from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models from controllers.web import web_ns from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db +from fields.base import ResponseModel +from libs.helper import dump_response from libs.passport import PassportService from libs.token import extract_webapp_access_token from models.enums import EndUserType @@ -28,7 +29,13 @@ class PassportQuery(BaseModel): register_schema_models(web_ns, PassportQuery) -register_response_schema_models(web_ns, AccessTokenData) + + +class PassportAccessTokenResponse(ResponseModel): + access_token: str + + +register_response_schema_models(web_ns, PassportAccessTokenResponse) @web_ns.route("/passport") @@ -45,7 +52,7 @@ class PassportResource(Resource): 404: "Application or user not found", } ) - @web_ns.response(200, "Passport retrieved successfully", web_ns.models[AccessTokenData.__name__]) + @web_ns.response(200, "Passport retrieved successfully", web_ns.models[PassportAccessTokenResponse.__name__]) def get(self): system_features = FeatureService.get_system_features() app_code = request.headers.get(HEADER_NAME_APP_CODE) @@ -59,8 +66,11 @@ class PassportResource(Resource): if app_auth_type != WebAppAuthType.PUBLIC: if not enterprise_user_decoded: raise WebAppAuthRequiredError() - return exchange_token_for_existing_web_user( - app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type + return dump_response( + PassportAccessTokenResponse, + exchange_token_for_existing_web_user( + app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type + ), ) # get site from db and check if it is normal @@ -110,12 +120,7 @@ class PassportResource(Resource): tk = PassportService().issue(payload) - response = make_response( - { - "access_token": tk, - } - ) - return response + return dump_response(PassportAccessTokenResponse, {"access_token": tk}) def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None: @@ -206,12 +211,7 @@ def exchange_token_for_existing_web_user( "exp": exp, } token: str = PassportService().issue(payload) - resp = make_response( - { - "access_token": token, - } - ) - return resp + return {"access_token": token} def _exchange_for_public_app_token(app_model, site, token_decoded): @@ -244,12 +244,7 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): tk = PassportService().issue(payload) - resp = make_response( - { - "access_token": tk, - } - ) - return resp + return {"access_token": tk} def generate_session_id(): diff --git a/api/openapi/markdown/console-openapi.md b/api/openapi/markdown/console-openapi.md index b3a0b8a6a71..71bde0530bf 100644 --- a/api/openapi/markdown/console-openapi.md +++ b/api/openapi/markdown/console-openapi.md @@ -8001,6 +8001,7 @@ Initiate OAuth login process | Code | Description | Schema | | ---- | ----------- | ------ | | 200 | Success | **application/json**: [SimpleResultResponse](#simpleresultresponse)
| +| 401 | Unauthorized | **application/json**: [SimpleResultMessageResponse](#simpleresultmessageresponse)
| ### [POST] /remote-files/upload #### Request Body diff --git a/api/openapi/markdown/web-openapi.md b/api/openapi/markdown/web-openapi.md index 0f368895ab6..715e7bc8e3b 100644 --- a/api/openapi/markdown/web-openapi.md +++ b/api/openapi/markdown/web-openapi.md @@ -600,7 +600,7 @@ Get authentication passport for web application access | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Passport retrieved successfully | **application/json**: [AccessTokenData](#accesstokendata)
| +| 200 | Passport retrieved successfully | **application/json**: [PassportAccessTokenResponse](#passportaccesstokenresponse)
| | 401 | Unauthorized - missing app code or invalid authentication | | | 404 | Application or user not found | | @@ -1429,6 +1429,12 @@ Form input definition. | text_to_speech | [JSONObject](#jsonobject) | | Yes | | user_input_form | [ [JSONObject](#jsonobject) ] | | Yes | +#### PassportAccessTokenResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| access_token | string | | Yes | + #### PassportQuery | Name | Type | Description | Required | diff --git a/api/tests/unit_tests/controllers/web/test_passport.py b/api/tests/unit_tests/controllers/web/test_passport.py index 58d58626b22..ebc64d94521 100644 --- a/api/tests/unit_tests/controllers/web/test_passport.py +++ b/api/tests/unit_tests/controllers/web/test_passport.py @@ -34,12 +34,11 @@ def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None: site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + call_state = {"calls": 0} def _scalar_side_effect(*_args, **_kwargs): - if not hasattr(_scalar_side_effect, "calls"): - _scalar_side_effect.calls = 0 - _scalar_side_effect.calls += 1 - return site if _scalar_side_effect.calls == 1 else app_model + call_state["calls"] += 1 + return site if call_state["calls"] == 1 else app_model db_session = SimpleNamespace(scalar=_scalar_side_effect) monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) @@ -53,12 +52,11 @@ def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None: def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None: site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + call_state = {"calls": 0} def _scalar_side_effect(*_args, **_kwargs): - if not hasattr(_scalar_side_effect, "calls"): - _scalar_side_effect.calls = 0 - _scalar_side_effect.calls += 1 - return site if _scalar_side_effect.calls == 1 else app_model + call_state["calls"] += 1 + return site if call_state["calls"] == 1 else app_model db_session = SimpleNamespace(scalar=_scalar_side_effect) monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) @@ -71,14 +69,13 @@ def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> No def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None: site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1") + call_state = {"calls": 0} def _scalar_side_effect(*_args, **_kwargs): - if not hasattr(_scalar_side_effect, "calls"): - _scalar_side_effect.calls = 0 - _scalar_side_effect.calls += 1 - if _scalar_side_effect.calls == 1: + call_state["calls"] += 1 + if call_state["calls"] == 1: return site - if _scalar_side_effect.calls == 2: + if call_state["calls"] == 2: return app_model return None diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index bfffd5cbb2c..984be6ddba9 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -95,7 +95,7 @@ class TestEmailCodeLoginApi: ): response = EmailCodeLoginApi().post() - assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}} + assert response == {"result": "success", "data": {"access_token": "new-access-token"}} mock_get_user.assert_called_once_with("User@Example.com") mock_revoke_token.assert_called_once_with("token-123") mock_login.assert_called_once() @@ -115,7 +115,7 @@ class TestLoginApi: ): response = LoginApi().post() - assert response.get_json()["data"]["access_token"] == "access-tok" + assert response["data"]["access_token"] == "access-tok" mock_auth.assert_called_once() @patch( diff --git a/api/tests/unit_tests/controllers/web/test_web_passport.py b/api/tests/unit_tests/controllers/web/test_web_passport.py index 19b1d8504a0..4e1a24a4da2 100644 --- a/api/tests/unit_tests/controllers/web/test_web_passport.py +++ b/api/tests/unit_tests/controllers/web/test_web_passport.py @@ -33,6 +33,7 @@ class TestDecodeEnterpriseWebappUserId: "user_id": "u1", } result = decode_enterprise_webapp_user_id("valid-jwt") + assert result is not None assert result["user_id"] == "u1" @patch("controllers.web.passport.PassportService") @@ -143,7 +144,7 @@ class TestPassportResource: with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): response = PassportResource().get() - assert response.get_json()["access_token"] == "issued-token" + assert response["access_token"] == "issued-token" mock_db.session.add.assert_called_once() mock_db.session.commit.assert_called_once() @@ -167,7 +168,7 @@ class TestPassportResource: with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}): response = PassportResource().get() - assert response.get_json()["access_token"] == "reused-token" + assert response["access_token"] == "reused-token" # Should not create a new end user mock_db.session.add.assert_not_called() diff --git a/packages/contracts/generated/api/console/refresh-token/types.gen.ts b/packages/contracts/generated/api/console/refresh-token/types.gen.ts index 81f76a749c4..bace0b5b85b 100644 --- a/packages/contracts/generated/api/console/refresh-token/types.gen.ts +++ b/packages/contracts/generated/api/console/refresh-token/types.gen.ts @@ -8,6 +8,11 @@ export type SimpleResultResponse = { result: string } +export type SimpleResultMessageResponse = { + message: string + result: string +} + export type PostRefreshTokenData = { body?: never path?: never @@ -15,6 +20,12 @@ export type PostRefreshTokenData = { url: '/refresh-token' } +export type PostRefreshTokenErrors = { + 401: SimpleResultMessageResponse +} + +export type PostRefreshTokenError = PostRefreshTokenErrors[keyof PostRefreshTokenErrors] + export type PostRefreshTokenResponses = { 200: SimpleResultResponse } diff --git a/packages/contracts/generated/api/console/refresh-token/zod.gen.ts b/packages/contracts/generated/api/console/refresh-token/zod.gen.ts index f10ac4c9e03..67f4cdff76c 100644 --- a/packages/contracts/generated/api/console/refresh-token/zod.gen.ts +++ b/packages/contracts/generated/api/console/refresh-token/zod.gen.ts @@ -9,6 +9,14 @@ export const zSimpleResultResponse = z.object({ result: z.string(), }) +/** + * SimpleResultMessageResponse + */ +export const zSimpleResultMessageResponse = z.object({ + message: z.string(), + result: z.string(), +}) + /** * Success */ diff --git a/packages/contracts/generated/api/web/types.gen.ts b/packages/contracts/generated/api/web/types.gen.ts index 722b3042841..ea3ed7f0fb5 100644 --- a/packages/contracts/generated/api/web/types.gen.ts +++ b/packages/contracts/generated/api/web/types.gen.ts @@ -421,6 +421,10 @@ export type Parameters = { user_input_form: Array } +export type PassportAccessTokenResponse = { + access_token: string +} + export type PassportQuery = { user_id?: string | null } @@ -1281,7 +1285,7 @@ export type GetPassportErrors = { } export type GetPassportResponses = { - 200: AccessTokenData + 200: PassportAccessTokenResponse } export type GetPassportResponse = GetPassportResponses[keyof GetPassportResponses] diff --git a/packages/contracts/generated/api/web/zod.gen.ts b/packages/contracts/generated/api/web/zod.gen.ts index d555ad5f85c..ee8a6174251 100644 --- a/packages/contracts/generated/api/web/zod.gen.ts +++ b/packages/contracts/generated/api/web/zod.gen.ts @@ -498,6 +498,13 @@ export const zMessageMoreLikeThisQuery = z.object({ response_mode: z.enum(['blocking', 'streaming']), }) +/** + * PassportAccessTokenResponse + */ +export const zPassportAccessTokenResponse = z.object({ + access_token: z.string(), +}) + /** * PassportQuery */ @@ -1178,7 +1185,7 @@ export const zGetPassportQuery = z.object({ /** * Passport retrieved successfully */ -export const zGetPassportResponse = zAccessTokenData +export const zGetPassportResponse = zPassportAccessTokenResponse export const zPostRemoteFilesUploadBody = zRemoteFileUploadPayload diff --git a/packages/contracts/openapi-ts.api.config.ts b/packages/contracts/openapi-ts.api.config.ts index 8fce8a25bd3..1adbf4fda8e 100644 --- a/packages/contracts/openapi-ts.api.config.ts +++ b/packages/contracts/openapi-ts.api.config.ts @@ -10,13 +10,21 @@ type SwaggerSchema = JsonObject & { $ref?: string } +type OpenApiMediaType = JsonObject & { + schema?: SwaggerSchema +} + +type OpenApiResponse = JsonObject & { + content?: Record +} + type OpenApiComponents = JsonObject & { schemas?: Record } type SwaggerOperation = JsonObject & { operationId?: string - responses?: Record + responses?: Record } type SwaggerDocument = JsonObject & { @@ -52,6 +60,17 @@ const currentDir = path.dirname(fileURLToPath(import.meta.url)) const apiOpenApiDir = path.resolve(currentDir, 'openapi') const operationMethods = new Set(['delete', 'get', 'patch', 'post', 'put']) +const pydanticDecimalStringPattern = '^(?!^[-+.]*$)[+-]?0*\\d*\\.?\\d*$' +const codegenSafeDecimalStringPattern = '^(?![-+.]*$)[+-]?0*\\d*\\.?\\d*$' + +const opaqueJsonContent = (): Record => ({ + 'application/json': { + schema: { + additionalProperties: true, + type: 'object', + }, + }, +}) const apiSpecs: ApiSpec[] = [ { filename: 'console-openapi.json', name: 'console' }, @@ -182,6 +201,46 @@ const addOperationIds = (document: SwaggerDocument) => { } } +const isOpaqueContractResponse = (response: OpenApiResponse) => { + const content = response.content + if (!isObject(content)) + return false + + return Object.entries(content).some(([mediaType, media]) => { + if (!isObject(media)) + return false + + return (mediaType === 'application/json' || mediaType === 'text/event-stream') && !('schema' in media) + }) +} + +const hasOpaqueContractSuccessResponse = (operation: SwaggerOperation) => { + return Object.entries(operation.responses ?? {}).some(([status, response]) => { + return /^2\d\d$/.test(status) && isObject(response) && isOpaqueContractResponse(response) + }) +} + +const normalizeOpaqueContractResponses = (document: SwaggerDocument) => { + // Some backend endpoints has no schema (e.g. external) and will trap heyapi here + // So we forge an opaque schema here + for (const pathItem of Object.values(document.paths ?? {})) { + for (const [method, operation] of Object.entries(pathItem)) { + if (!operationMethods.has(method) || !isObject(operation)) + continue + + const swaggerOperation = operation as SwaggerOperation + if (!hasOpaqueContractSuccessResponse(swaggerOperation)) + continue + + Object.values(swaggerOperation.responses ?? {}) + .filter(response => isObject(response) && isOpaqueContractResponse(response)) + .forEach((response) => { + response.content = opaqueJsonContent() + }) + } + } +} + const hasSuccessResponse = (operation: SwaggerOperation) => { return Object.entries(operation.responses ?? {}).some(([status, response]) => { if (!/^2\d\d$/.test(status)) @@ -215,6 +274,7 @@ const filterContractOperations = (document: SwaggerDocument) => { } const normalizeApiSwagger = (document: SwaggerDocument) => { + normalizeOpaqueContractResponses(document) filterContractOperations(document) addOperationIds(document) @@ -380,10 +440,20 @@ const createApiConfig = (job: ApiJob): UserConfig => ({ 'name': 'zod', '~resolvers': { string: (ctx) => { - if (ctx.schema.format !== 'binary') - return undefined + if (ctx.schema.format === 'binary') + return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File'))) - return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File'))) + if (ctx.schema.pattern === pydanticDecimalStringPattern) { + // the pydantic generated regex will emit error like + // regexp/no-useless-assertions, so patch the regex here + return $(ctx.symbols.z) + .attr('string') + .call() + .attr('regex') + .call($.regexp(codegenSafeDecimalStringPattern)) + } + + return undefined }, }, },