mirror of
https://github.com/langgenius/dify.git
synced 2026-06-26 14:51:13 +08:00
refactor(api): migrate web auth endpoints to BaseModel
This commit is contained in:
parent
bb921bcc45
commit
8ec9622e19
@ -9,7 +9,12 @@ from werkzeug.exceptions import Unauthorized
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language
|
||||
from controllers.common.fields import SimpleResultDataResponse, SimpleResultOptionalDataResponse, SimpleResultResponse
|
||||
from controllers.common.fields import (
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultMessageResponse,
|
||||
SimpleResultOptionalDataResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
@ -87,6 +92,7 @@ register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPay
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultMessageResponse,
|
||||
SimpleResultOptionalDataResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
@ -154,16 +160,19 @@ class LoginApi(Resource):
|
||||
if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available():
|
||||
raise WorkspacesLimitExceeded()
|
||||
else:
|
||||
return {
|
||||
"result": "fail",
|
||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
}
|
||||
return SimpleResultOptionalDataResponse(
|
||||
result="fail",
|
||||
data="workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
).model_dump(mode="json")
|
||||
|
||||
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
# response-contract:ignore cookie-bearing Flask response
|
||||
response = make_response(
|
||||
SimpleResultOptionalDataResponse(result="success").model_dump(mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
@ -178,12 +187,11 @@ class LogoutApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
# response-contract:ignore cookie-bearing Flask response
|
||||
response = make_response(SimpleResultResponse(result="success").model_dump(mode="json"))
|
||||
if not isinstance(account, flask_login.AnonymousUserMixin):
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Clear cookies on logout
|
||||
clear_access_token_from_cookie(response)
|
||||
@ -219,7 +227,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
return SimpleResultDataResponse(result="success", data=token).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login")
|
||||
@ -252,7 +260,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
else:
|
||||
token = AccountService.send_email_code_login_email(account=account, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
return SimpleResultDataResponse(result="success", data=token).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login/validity")
|
||||
@ -326,7 +334,8 @@ class EmailCodeLoginApi(Resource):
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
# response-contract:ignore cookie-bearing Flask response
|
||||
response = make_response(SimpleResultResponse(result="success").model_dump(mode="json"))
|
||||
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
# Set HTTP-only secure cookies for tokens
|
||||
@ -338,18 +347,22 @@ class EmailCodeLoginApi(Resource):
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(401, "Unauthorized", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def post(self):
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = extract_refresh_token(request)
|
||||
|
||||
if not refresh_token:
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
return SimpleResultMessageResponse(result="fail", message="No refresh token provided").model_dump(
|
||||
mode="json"
|
||||
), 401
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(refresh_token, session=db.session)
|
||||
|
||||
# Create response with new cookies
|
||||
response = make_response({"result": "success"})
|
||||
# response-contract:ignore cookie-bearing Flask response
|
||||
response = make_response(SimpleResultResponse(result="success").model_dump(mode="json"))
|
||||
|
||||
# Update cookies with new tokens
|
||||
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
|
||||
@ -357,7 +370,7 @@ class RefreshTokenApi(Resource):
|
||||
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
return SimpleResultMessageResponse(result="fail", message=str(e)).model_dump(mode="json"), 401
|
||||
|
||||
|
||||
def _get_account_with_case_fallback(email: str):
|
||||
|
||||
@ -9,6 +9,7 @@ from werkzeug.exceptions import Unauthorized
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.fields import (
|
||||
AccessTokenData,
|
||||
AccessTokenResultResponse,
|
||||
LoginStatusResponse,
|
||||
SimpleResultDataResponse,
|
||||
@ -115,9 +116,10 @@ class LoginApi(Resource):
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
return AccessTokenResultResponse(result="success", data=AccessTokenData(access_token=token)).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
# this api helps frontend to check whether user is authenticated
|
||||
@ -136,14 +138,12 @@ class LoginStatusApi(Resource):
|
||||
)
|
||||
@web_ns.response(200, "Login status", web_ns.models[LoginStatusResponse.__name__])
|
||||
def get(self):
|
||||
app_code = request.args.get("app_code")
|
||||
user_id = request.args.get("user_id")
|
||||
query = LoginStatusQuery.model_validate(request.args.to_dict(flat=True))
|
||||
app_code = query.app_code
|
||||
user_id = query.user_id
|
||||
token = extract_webapp_access_token(request)
|
||||
if not app_code:
|
||||
return {
|
||||
"logged_in": bool(token),
|
||||
"app_logged_in": False,
|
||||
}
|
||||
return LoginStatusResponse(logged_in=bool(token), app_logged_in=False).model_dump(mode="json")
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
is_public = not dify_config.ENTERPRISE_ENABLED or not WebAppAuthService.is_app_require_permission_check(
|
||||
app_id=app_id
|
||||
@ -165,10 +165,7 @@ class LoginStatusApi(Resource):
|
||||
except Exception:
|
||||
app_logged_in = False
|
||||
|
||||
return {
|
||||
"logged_in": user_logged_in,
|
||||
"app_logged_in": app_logged_in,
|
||||
}
|
||||
return LoginStatusResponse(logged_in=user_logged_in, app_logged_in=app_logged_in).model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/logout")
|
||||
@ -183,7 +180,8 @@ class LogoutApi(Resource):
|
||||
)
|
||||
@web_ns.response(200, "Logout successful", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
response = make_response({"result": "success"})
|
||||
# response-contract:ignore hand-crafted response
|
||||
response = make_response(SimpleResultResponse(result="success").model_dump(mode="json"))
|
||||
# enterprise SSO sets same site to None in https deployment
|
||||
# so we need to logout by calling api
|
||||
clear_webapp_access_token_from_cookie(response, samesite="None")
|
||||
@ -216,9 +214,8 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
account = WebAppAuthService.get_user_through_email(payload.email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||
return SimpleResultDataResponse(result="success", data=token).model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/email-code-login/validity")
|
||||
@ -277,9 +274,10 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
return AccessTokenResultResponse(result="success", data=AccessTokenData(access_token=token)).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
|
||||
@ -2,7 +2,7 @@ import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func, select
|
||||
@ -10,11 +10,12 @@ from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
from controllers.common.fields import AccessTokenData
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_access_token
|
||||
from models.enums import EndUserType
|
||||
@ -28,7 +29,13 @@ class PassportQuery(BaseModel):
|
||||
|
||||
|
||||
register_schema_models(web_ns, PassportQuery)
|
||||
register_response_schema_models(web_ns, AccessTokenData)
|
||||
|
||||
|
||||
class PassportAccessTokenResponse(ResponseModel):
|
||||
access_token: str
|
||||
|
||||
|
||||
register_response_schema_models(web_ns, PassportAccessTokenResponse)
|
||||
|
||||
|
||||
@web_ns.route("/passport")
|
||||
@ -45,7 +52,7 @@ class PassportResource(Resource):
|
||||
404: "Application or user not found",
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Passport retrieved successfully", web_ns.models[AccessTokenData.__name__])
|
||||
@web_ns.response(200, "Passport retrieved successfully", web_ns.models[PassportAccessTokenResponse.__name__])
|
||||
def get(self):
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
@ -59,8 +66,11 @@ class PassportResource(Resource):
|
||||
if app_auth_type != WebAppAuthType.PUBLIC:
|
||||
if not enterprise_user_decoded:
|
||||
raise WebAppAuthRequiredError()
|
||||
return exchange_token_for_existing_web_user(
|
||||
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
|
||||
return dump_response(
|
||||
PassportAccessTokenResponse,
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded, auth_type=app_auth_type
|
||||
),
|
||||
)
|
||||
|
||||
# get site from db and check if it is normal
|
||||
@ -110,12 +120,7 @@ class PassportResource(Resource):
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
response = make_response(
|
||||
{
|
||||
"access_token": tk,
|
||||
}
|
||||
)
|
||||
return response
|
||||
return dump_response(PassportAccessTokenResponse, {"access_token": tk})
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
|
||||
@ -206,12 +211,7 @@ def exchange_token_for_existing_web_user(
|
||||
"exp": exp,
|
||||
}
|
||||
token: str = PassportService().issue(payload)
|
||||
resp = make_response(
|
||||
{
|
||||
"access_token": token,
|
||||
}
|
||||
)
|
||||
return resp
|
||||
return {"access_token": token}
|
||||
|
||||
|
||||
def _exchange_for_public_app_token(app_model, site, token_decoded):
|
||||
@ -244,12 +244,7 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
resp = make_response(
|
||||
{
|
||||
"access_token": tk,
|
||||
}
|
||||
)
|
||||
return resp
|
||||
return {"access_token": tk}
|
||||
|
||||
|
||||
def generate_session_id():
|
||||
|
||||
@ -8001,6 +8001,7 @@ Initiate OAuth login process
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Success | **application/json**: [SimpleResultResponse](#simpleresultresponse)<br> |
|
||||
| 401 | Unauthorized | **application/json**: [SimpleResultMessageResponse](#simpleresultmessageresponse)<br> |
|
||||
|
||||
### [POST] /remote-files/upload
|
||||
#### Request Body
|
||||
|
||||
@ -600,7 +600,7 @@ Get authentication passport for web application access
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Passport retrieved successfully | **application/json**: [AccessTokenData](#accesstokendata)<br> |
|
||||
| 200 | Passport retrieved successfully | **application/json**: [PassportAccessTokenResponse](#passportaccesstokenresponse)<br> |
|
||||
| 401 | Unauthorized - missing app code or invalid authentication | |
|
||||
| 404 | Application or user not found | |
|
||||
|
||||
@ -1429,6 +1429,12 @@ Form input definition.
|
||||
| text_to_speech | [JSONObject](#jsonobject) | | Yes |
|
||||
| user_input_form | [ [JSONObject](#jsonobject) ] | | Yes |
|
||||
|
||||
#### PassportAccessTokenResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| access_token | string | | Yes |
|
||||
|
||||
#### PassportQuery
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
|
||||
@ -34,12 +34,11 @@ def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch)
|
||||
def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
call_state = {"calls": 0}
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
call_state["calls"] += 1
|
||||
return site if call_state["calls"] == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
@ -53,12 +52,11 @@ def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
call_state = {"calls": 0}
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
call_state["calls"] += 1
|
||||
return site if call_state["calls"] == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
@ -71,14 +69,13 @@ def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1")
|
||||
call_state = {"calls": 0}
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
if _scalar_side_effect.calls == 1:
|
||||
call_state["calls"] += 1
|
||||
if call_state["calls"] == 1:
|
||||
return site
|
||||
if _scalar_side_effect.calls == 2:
|
||||
if call_state["calls"] == 2:
|
||||
return app_model
|
||||
return None
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ class TestEmailCodeLoginApi:
|
||||
):
|
||||
response = EmailCodeLoginApi().post()
|
||||
|
||||
assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}}
|
||||
assert response == {"result": "success", "data": {"access_token": "new-access-token"}}
|
||||
mock_get_user.assert_called_once_with("User@Example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
@ -115,7 +115,7 @@ class TestLoginApi:
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.get_json()["data"]["access_token"] == "access-tok"
|
||||
assert response["data"]["access_token"] == "access-tok"
|
||||
mock_auth.assert_called_once()
|
||||
|
||||
@patch(
|
||||
|
||||
@ -33,6 +33,7 @@ class TestDecodeEnterpriseWebappUserId:
|
||||
"user_id": "u1",
|
||||
}
|
||||
result = decode_enterprise_webapp_user_id("valid-jwt")
|
||||
assert result is not None
|
||||
assert result["user_id"] == "u1"
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@ -143,7 +144,7 @@ class TestPassportResource:
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "issued-token"
|
||||
assert response["access_token"] == "issued-token"
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -167,7 +168,7 @@ class TestPassportResource:
|
||||
with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "reused-token"
|
||||
assert response["access_token"] == "reused-token"
|
||||
# Should not create a new end user
|
||||
mock_db.session.add.assert_not_called()
|
||||
|
||||
|
||||
@ -8,6 +8,11 @@ export type SimpleResultResponse = {
|
||||
result: string
|
||||
}
|
||||
|
||||
export type SimpleResultMessageResponse = {
|
||||
message: string
|
||||
result: string
|
||||
}
|
||||
|
||||
export type PostRefreshTokenData = {
|
||||
body?: never
|
||||
path?: never
|
||||
@ -15,6 +20,12 @@ export type PostRefreshTokenData = {
|
||||
url: '/refresh-token'
|
||||
}
|
||||
|
||||
export type PostRefreshTokenErrors = {
|
||||
401: SimpleResultMessageResponse
|
||||
}
|
||||
|
||||
export type PostRefreshTokenError = PostRefreshTokenErrors[keyof PostRefreshTokenErrors]
|
||||
|
||||
export type PostRefreshTokenResponses = {
|
||||
200: SimpleResultResponse
|
||||
}
|
||||
|
||||
@ -9,6 +9,14 @@ export const zSimpleResultResponse = z.object({
|
||||
result: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* SimpleResultMessageResponse
|
||||
*/
|
||||
export const zSimpleResultMessageResponse = z.object({
|
||||
message: z.string(),
|
||||
result: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* Success
|
||||
*/
|
||||
|
||||
@ -421,6 +421,10 @@ export type Parameters = {
|
||||
user_input_form: Array<JsonObject>
|
||||
}
|
||||
|
||||
export type PassportAccessTokenResponse = {
|
||||
access_token: string
|
||||
}
|
||||
|
||||
export type PassportQuery = {
|
||||
user_id?: string | null
|
||||
}
|
||||
@ -1281,7 +1285,7 @@ export type GetPassportErrors = {
|
||||
}
|
||||
|
||||
export type GetPassportResponses = {
|
||||
200: AccessTokenData
|
||||
200: PassportAccessTokenResponse
|
||||
}
|
||||
|
||||
export type GetPassportResponse = GetPassportResponses[keyof GetPassportResponses]
|
||||
|
||||
@ -498,6 +498,13 @@ export const zMessageMoreLikeThisQuery = z.object({
|
||||
response_mode: z.enum(['blocking', 'streaming']),
|
||||
})
|
||||
|
||||
/**
|
||||
* PassportAccessTokenResponse
|
||||
*/
|
||||
export const zPassportAccessTokenResponse = z.object({
|
||||
access_token: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* PassportQuery
|
||||
*/
|
||||
@ -1178,7 +1185,7 @@ export const zGetPassportQuery = z.object({
|
||||
/**
|
||||
* Passport retrieved successfully
|
||||
*/
|
||||
export const zGetPassportResponse = zAccessTokenData
|
||||
export const zGetPassportResponse = zPassportAccessTokenResponse
|
||||
|
||||
export const zPostRemoteFilesUploadBody = zRemoteFileUploadPayload
|
||||
|
||||
|
||||
@ -10,13 +10,21 @@ type SwaggerSchema = JsonObject & {
|
||||
$ref?: string
|
||||
}
|
||||
|
||||
type OpenApiMediaType = JsonObject & {
|
||||
schema?: SwaggerSchema
|
||||
}
|
||||
|
||||
type OpenApiResponse = JsonObject & {
|
||||
content?: Record<string, OpenApiMediaType>
|
||||
}
|
||||
|
||||
type OpenApiComponents = JsonObject & {
|
||||
schemas?: Record<string, SwaggerSchema>
|
||||
}
|
||||
|
||||
type SwaggerOperation = JsonObject & {
|
||||
operationId?: string
|
||||
responses?: Record<string, unknown>
|
||||
responses?: Record<string, OpenApiResponse>
|
||||
}
|
||||
|
||||
type SwaggerDocument = JsonObject & {
|
||||
@ -52,6 +60,17 @@ const currentDir = path.dirname(fileURLToPath(import.meta.url))
|
||||
const apiOpenApiDir = path.resolve(currentDir, 'openapi')
|
||||
|
||||
const operationMethods = new Set(['delete', 'get', 'patch', 'post', 'put'])
|
||||
const pydanticDecimalStringPattern = '^(?!^[-+.]*$)[+-]?0*\\d*\\.?\\d*$'
|
||||
const codegenSafeDecimalStringPattern = '^(?![-+.]*$)[+-]?0*\\d*\\.?\\d*$'
|
||||
|
||||
const opaqueJsonContent = (): Record<string, OpenApiMediaType> => ({
|
||||
'application/json': {
|
||||
schema: {
|
||||
additionalProperties: true,
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const apiSpecs: ApiSpec[] = [
|
||||
{ filename: 'console-openapi.json', name: 'console' },
|
||||
@ -182,6 +201,46 @@ const addOperationIds = (document: SwaggerDocument) => {
|
||||
}
|
||||
}
|
||||
|
||||
const isOpaqueContractResponse = (response: OpenApiResponse) => {
|
||||
const content = response.content
|
||||
if (!isObject(content))
|
||||
return false
|
||||
|
||||
return Object.entries(content).some(([mediaType, media]) => {
|
||||
if (!isObject(media))
|
||||
return false
|
||||
|
||||
return (mediaType === 'application/json' || mediaType === 'text/event-stream') && !('schema' in media)
|
||||
})
|
||||
}
|
||||
|
||||
const hasOpaqueContractSuccessResponse = (operation: SwaggerOperation) => {
|
||||
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
|
||||
return /^2\d\d$/.test(status) && isObject(response) && isOpaqueContractResponse(response)
|
||||
})
|
||||
}
|
||||
|
||||
const normalizeOpaqueContractResponses = (document: SwaggerDocument) => {
|
||||
// Some backend endpoints has no schema (e.g. external) and will trap heyapi here
|
||||
// So we forge an opaque schema here
|
||||
for (const pathItem of Object.values(document.paths ?? {})) {
|
||||
for (const [method, operation] of Object.entries(pathItem)) {
|
||||
if (!operationMethods.has(method) || !isObject(operation))
|
||||
continue
|
||||
|
||||
const swaggerOperation = operation as SwaggerOperation
|
||||
if (!hasOpaqueContractSuccessResponse(swaggerOperation))
|
||||
continue
|
||||
|
||||
Object.values(swaggerOperation.responses ?? {})
|
||||
.filter(response => isObject(response) && isOpaqueContractResponse(response))
|
||||
.forEach((response) => {
|
||||
response.content = opaqueJsonContent()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const hasSuccessResponse = (operation: SwaggerOperation) => {
|
||||
return Object.entries(operation.responses ?? {}).some(([status, response]) => {
|
||||
if (!/^2\d\d$/.test(status))
|
||||
@ -215,6 +274,7 @@ const filterContractOperations = (document: SwaggerDocument) => {
|
||||
}
|
||||
|
||||
const normalizeApiSwagger = (document: SwaggerDocument) => {
|
||||
normalizeOpaqueContractResponses(document)
|
||||
filterContractOperations(document)
|
||||
addOperationIds(document)
|
||||
|
||||
@ -380,10 +440,20 @@ const createApiConfig = (job: ApiJob): UserConfig => ({
|
||||
'name': 'zod',
|
||||
'~resolvers': {
|
||||
string: (ctx) => {
|
||||
if (ctx.schema.format !== 'binary')
|
||||
return undefined
|
||||
if (ctx.schema.format === 'binary')
|
||||
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
|
||||
|
||||
return $(ctx.symbols.z).attr('custom').call().generic($.type.or($.type('Blob'), $.type('File')))
|
||||
if (ctx.schema.pattern === pydanticDecimalStringPattern) {
|
||||
// the pydantic generated regex will emit error like
|
||||
// regexp/no-useless-assertions, so patch the regex here
|
||||
return $(ctx.symbols.z)
|
||||
.attr('string')
|
||||
.call()
|
||||
.attr('regex')
|
||||
.call($.regexp(codegenSafeDecimalStringPattern))
|
||||
}
|
||||
|
||||
return undefined
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user