refactor(openapi): unify request validation behind @accepts/@returns decorators (#37216)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
L1nSn0w 2026-06-10 11:02:24 +08:00 committed by GitHub
parent c9bb740a6b
commit 629e046303
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 461 additions and 157 deletions

View File

@ -0,0 +1,81 @@
"""Request/response contract decorators for the openapi controllers.
``@accepts`` and ``@returns`` own one slice of the contract from a single model
reference emitting the Swagger schema AND doing the runtime validation/
serialisation so the advertised and enforced contracts can't drift. Validation
failures map to a single shape: 422.
They must sit BELOW ``@auth_router.guard`` so auth runs before validation and the
``view.__wrapped__`` unit-test seam unwraps exactly the guard layer.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import Any
from flask import request
from flask_restx import abort
from pydantic import BaseModel, ValidationError
from controllers.common.schema import query_params_from_model, query_params_from_request
from controllers.openapi import openapi_ns
def accepts(*, query: type[BaseModel] | None = None, body: type[BaseModel] | None = None) -> Callable:
"""Validate ``query``/``body`` against the models and inject them as keyword-only kwargs.
Emits the matching Swagger schema from the same models, so doc and enforcement
stay in lockstep.
"""
def decorator(view: Callable) -> Callable:
@wraps(view)
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
if query is not None:
kwargs["query"] = query_params_from_request(query)
if body is not None:
kwargs["body"] = body.model_validate(request.get_json(silent=True) or {})
except ValidationError as exc:
# Sanitized 422 — no pydantic `url` (version) or `input` (user payload) leak.
abort(
422,
message="Request validation failed",
errors=exc.errors(include_url=False, include_input=False, include_context=False),
)
return view(*args, **kwargs)
if query is not None:
openapi_ns.doc(params=query_params_from_model(query))(wrapper)
if body is not None:
openapi_ns.expect(openapi_ns.models[body.__name__])(wrapper)
return wrapper
return decorator
def returns(code: int, model: type[BaseModel], description: str | None = None) -> Callable:
"""Serialise the handler's returned model and emit the response schema.
Accepts a ``BaseModel`` (serialised with ``code``) or a ``(model, status[, headers])``
tuple (status/headers honoured). Other returns a bare ``(dict, status)``, an SSE
``Response`` pass through untouched.
"""
def decorator(view: Callable) -> Callable:
@wraps(view)
def wrapper(*args: Any, **kwargs: Any) -> Any:
result = view(*args, **kwargs)
if isinstance(result, BaseModel):
return result.model_dump(mode="json"), code
if isinstance(result, tuple) and result and isinstance(result[0], BaseModel):
payload, *rest = result
return (payload.model_dump(mode="json"), *rest)
return result
openapi_ns.response(code, description or model.__name__, openapi_ns.models[model.__name__])(wrapper)
return wrapper
return decorator

View File

@ -9,15 +9,16 @@ from flask_restx import Resource
from configs import dify_config
from controllers.openapi import openapi_ns
from controllers.openapi._contract import returns
from controllers.openapi._models import ServerVersionResponse
@openapi_ns.route("/_version")
class VersionApi(Resource):
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
@returns(200, ServerVersionResponse, description="Server version")
def get(self):
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
return ServerVersionResponse(
version=dify_config.project.version,
edition=edition,
).model_dump(mode="json")
)

View File

@ -2,17 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import NotFound, UnprocessableEntity
from werkzeug.exceptions import NotFound
from controllers.common.schema import query_params_from_model
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import (
AccountPayload,
AccountResponse,
PaginationEnvelope,
RevokeResponse,
SessionListQuery,
SessionListResponse,
@ -42,8 +39,8 @@ from services.oauth_device_flow import (
@openapi_ns.route("/account")
class AccountApi(Resource):
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@returns(200, AccountResponse, description="Account info")
def get(self, *, auth_data: AuthData):
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
@ -58,31 +55,27 @@ class AccountApi(Resource):
account=_account_payload(account) if account else None,
workspaces=[_workspace_payload(m) for m in memberships],
default_workspace_id=default_ws_id,
).model_dump(mode="json")
)
@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@returns(200, RevokeResponse, description="Session revoked")
def delete(self, *, auth_data: AuthData):
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
return RevokeResponse(status="revoked")
@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
@openapi_ns.doc(params=query_params_from_model(SessionListQuery))
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
# Validate page/limit through the same model the contract advertises (extra='forbid',
# page>=1, 1<=limit<=MAX_PAGE_LIMIT) so the server actually enforces those bounds rather
# than silently coercing (e.g. page=0 -> empty slice). Mirrors AppDescribeQuery.
try:
query = SessionListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
@returns(200, SessionListResponse, description="Session list")
@accepts(query=SessionListQuery)
def get(self, *, auth_data: AuthData, query: SessionListQuery):
# SessionListQuery enforces the advertised bounds (extra='forbid', page>=1,
# 1<=limit<=MAX_PAGE_LIMIT) so the server rejects out-of-range paging rather
# than silently coercing (e.g. page=0 -> empty slice).
ctx = get_auth_ctx()
now = datetime.now(UTC)
page = query.page
@ -106,16 +99,19 @@ class AccountSessionsApi(Resource):
for r in sliced
]
return (
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
200,
return SessionListResponse(
page=page,
limit=limit,
total=total,
has_more=page * limit < total,
data=items,
)
@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@returns(200, RevokeResponse, description="Session revoked")
def delete(self, session_id: str, *, auth_data: AuthData):
ctx = get_auth_ctx()
@ -125,7 +121,7 @@ class AccountSessionByIdApi(Resource):
raise NotFound("session not found")
revoke_oauth_token(db.session, redis_client, session_id)
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
return RevokeResponse(status="revoked")
def _iso(dt: datetime | None) -> str | None:

View File

@ -7,14 +7,13 @@ from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import AppRunRequest, TaskStopResponse
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
@ -123,23 +122,18 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
@openapi_ns.response(200, "Run result (SSE stream)")
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, *, auth_data: AuthData):
@openapi_ns.response(200, "Run result (SSE stream)")
@accepts(body=AppRunRequest)
def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest):
app_model, caller, caller_kind = auth_data.require_app_context()
body = request.get_json(silent=True) or {}
try:
payload = AppRunRequest.model_validate(body)
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
handler = _DISPATCH.get(app_model.mode)
if handler is None:
raise UnprocessableEntity("mode_not_runnable")
try:
stream_obj = handler(app_model, caller, payload)
stream_obj = handler(app_model, caller, body)
except HTTPException:
raise
except Exception:
@ -159,10 +153,10 @@ class AppRunApi(Resource):
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
@openapi_ns.response(200, "Task stopped", openapi_ns.models[TaskStopResponse.__name__])
@auth_router.guard(scope=Scope.APPS_RUN)
@returns(200, TaskStopResponse, description="Task stopped")
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
app_model, caller, caller_kind = auth_data.require_app_context()
AppQueueManager.set_stop_flag_no_user_check(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}
return TaskStopResponse(result="success")

View File

@ -5,14 +5,12 @@ from __future__ import annotations
import uuid as _uuid
from typing import Any, cast
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
from controllers.common.fields import Parameters
from controllers.common.schema import query_params_from_model
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
from controllers.openapi._models import (
AppDescribeInfo,
@ -88,15 +86,11 @@ def parameters_payload(app: App) -> dict:
@openapi_ns.route("/apps/<string:app_id>/describe")
class AppDescribeApi(AppReadResource):
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, app_id: str, *, auth_data: AuthData):
try:
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
@returns(200, AppDescribeResponse, description="App description")
@accepts(query=AppDescribeQuery)
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
# describe is UUID-only (workspace_id query param dropped in #37212).
app = self._load(app_id)
requested = query.fields
@ -133,35 +127,22 @@ class AppDescribeApi(AppReadResource):
except AppUnavailableError:
input_schema = dict(EMPTY_INPUT_SCHEMA)
return (
AppDescribeResponse(
info=info,
parameters=parameters,
input_schema=input_schema,
).model_dump(mode="json", exclude_none=False),
200,
return AppDescribeResponse(
info=info,
parameters=parameters,
input_schema=input_schema,
)
@openapi_ns.route("/apps")
class AppListApi(Resource):
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
@returns(200, AppListResponse, description="App list")
@accepts(query=AppListQuery)
def get(self, *, auth_data: AuthData, query: AppListQuery):
workspace_id = query.workspace_id
empty = (
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
mode="json"
),
200,
)
empty = AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[])
if query.name:
try:
@ -189,7 +170,7 @@ class AppListApi(Resource):
workspace_name=tenant_name,
)
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
return env.model_dump(mode="json"), 200
return env
tag_ids: list[str] | None = None
if query.tag:
@ -240,4 +221,4 @@ class AppListApi(Resource):
has_more=query.page * query.limit < cast(int, pagination.total),
data=items,
)
return env.model_dump(mode="json"), 200
return env

View File

@ -7,12 +7,10 @@ EE blueprint chain so this module is unreachable there.
from __future__ import annotations
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import (
AppListRow,
PermittedExternalAppsListQuery,
@ -30,20 +28,14 @@ from services.enterprise.app_permitted_service import list_permitted_apps
@openapi_ns.route("/permitted-external-apps")
class PermittedExternalAppsListApi(Resource):
@openapi_ns.response(
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
)
@auth_router.guard(
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
edition=frozenset({Edition.EE}),
)
def get(self, *, auth_data: AuthData):
try:
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
@returns(200, PermittedExternalAppsListResponse, description="Permitted external apps list")
@accepts(query=PermittedExternalAppsListQuery)
def get(self, *, auth_data: AuthData, query: PermittedExternalAppsListQuery):
page_result = list_permitted_apps(
page=query.page,
limit=query.limit,
@ -55,7 +47,7 @@ class PermittedExternalAppsListApi(Resource):
env = PermittedExternalAppsListResponse(
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
)
return env.model_dump(mode="json"), 200
return env
apps_by_id: dict[str, App] = {
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
@ -89,4 +81,4 @@ class PermittedExternalAppsListApi(Resource):
has_more=query.page * query.limit < page_result.total,
data=items,
)
return env.model_dump(mode="json"), 200
return env

View File

@ -17,6 +17,7 @@ from controllers.common.errors import (
UnsupportedFileTypeError,
)
from controllers.openapi import openapi_ns
from controllers.openapi._contract import returns
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
@ -38,8 +39,8 @@ class AppFileUploadApi(Resource):
415: "Unsupported file type or blocked extension",
}
)
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
@auth_router.guard(scope=Scope.APPS_RUN)
@returns(HTTPStatus.CREATED, FileResponse, description="File uploaded")
def post(self, app_id: str, *, auth_data: AuthData):
app_model, caller, _ = auth_data.require_app_context()
if "file" not in request.files:
@ -69,5 +70,4 @@ class AppFileUploadApi(Resource):
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError(exc.description)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201
return FileResponse.model_validate(upload_file, from_attributes=True)

View File

@ -10,13 +10,14 @@ from __future__ import annotations
import json
import logging
from flask import Response, request
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import FormSubmitResponse
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
@ -70,12 +71,11 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
@openapi_ns.response(200, "Form submitted", openapi_ns.models[FormSubmitResponse.__name__])
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
@returns(200, FormSubmitResponse, description="Form submitted")
@accepts(body=HumanInputFormSubmitPayload)
def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload):
app_model, caller, caller_kind = auth_data.require_app_context()
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
@ -100,12 +100,12 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=form.recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
selected_action_id=body.action,
form_data=body.inputs,
submission_user_id=submission_user_id,
submission_end_user_id=submission_end_user_id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200
return FormSubmitResponse()

View File

@ -1,11 +1,12 @@
from flask_restx import Resource
from controllers.openapi import openapi_ns
from controllers.openapi._contract import returns
from controllers.openapi._models import HealthResponse
@openapi_ns.route("/_health")
class HealthApi(Resource):
@openapi_ns.response(200, "Health check", openapi_ns.models[HealthResponse.__name__])
@returns(200, HealthResponse, description="Health check")
def get(self):
return {"ok": True}
return HealthResponse(ok=True)

View File

@ -14,14 +14,13 @@ from __future__ import annotations
from itertools import starmap
from urllib import parse
from flask import jsonify, make_response, request
from flask import jsonify, make_response
from flask_restx import Resource
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
from controllers.openapi._models import (
MemberActionResponse,
MemberInvitePayload,
@ -53,14 +52,6 @@ from services.errors.account import (
from services.feature_service import FeatureService
def _validate_body[M: BaseModel](model: type[M]) -> M:
body = request.get_json(silent=True) or {}
try:
return model.model_validate(body)
except ValidationError as exc:
raise BadRequest(str(exc))
def _member_response(account: Account) -> MemberResponse:
return MemberResponse(
id=str(account.id),
@ -118,18 +109,18 @@ def _check_member_invite_quota(tenant_id: str) -> None:
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@returns(200, WorkspaceListResponse, description="Workspace list")
def get(self, *, auth_data: AuthData):
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows)))
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@returns(200, WorkspaceDetailResponse, description="Workspace detail")
def get(self, workspace_id: str, *, auth_data: AuthData):
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
@ -137,7 +128,7 @@ class WorkspaceByIdApi(Resource):
raise NotFound("workspace not found")
tenant, membership = row
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
return _workspace_detail(tenant, membership)
@openapi_ns.route("/workspaces/<string:workspace_id>/switch")
@ -149,8 +140,8 @@ class WorkspaceSwitchApi(Resource):
that ``hosts.yml`` never diverges from the server's ``current`` state.
"""
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
@returns(200, WorkspaceDetailResponse, description="Workspace detail")
def post(self, workspace_id: str, *, auth_data: AuthData):
account = _load_account(auth_data.account_id)
@ -163,7 +154,7 @@ class WorkspaceSwitchApi(Resource):
if row is None:
raise NotFound("workspace not found")
tenant, membership = row
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
return _workspace_detail(tenant, membership)
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
@ -174,15 +165,10 @@ class WorkspaceMembersApi(Resource):
assigned through invite (ownership transfer is console-only).
"""
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
@auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, workspace_id: str, *, auth_data: AuthData):
try:
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise BadRequest(str(exc))
@returns(200, MemberListResponse, description="Member list")
@accepts(query=MemberListQuery)
def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery):
tenant = _load_tenant(workspace_id)
members = TenantService.get_tenant_members(tenant)
total = len(members)
@ -194,17 +180,16 @@ class WorkspaceMembersApi(Resource):
total=total,
has_more=query.page * query.limit < total,
data=[_member_response(m) for m in page_items],
).model_dump(mode="json"), 200
)
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
def post(self, workspace_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberInvitePayload)
@returns(201, MemberInviteResponse, description="Member invited")
@accepts(body=MemberInvitePayload)
def post(self, workspace_id: str, *, auth_data: AuthData, body: MemberInvitePayload):
inviter = _load_account(auth_data.account_id)
tenant = _load_tenant(workspace_id)
@ -213,9 +198,9 @@ class WorkspaceMembersApi(Resource):
try:
token = RegisterService.invite_new_member(
tenant=tenant,
email=payload.email,
email=body.email,
language=None,
role=payload.role,
role=body.role,
inviter=inviter,
)
except AccountAlreadyInTenantError as exc:
@ -225,7 +210,7 @@ class WorkspaceMembersApi(Resource):
except AccountRegisterError as exc:
raise BadRequest(str(exc))
normalized_email = payload.email.lower()
normalized_email = body.email.lower()
member = AccountService.get_account_by_email_with_case_fallback(normalized_email)
if member is None:
# invite_new_member just created or fetched this account.
@ -235,11 +220,11 @@ class WorkspaceMembersApi(Resource):
invite_url = f"{dify_config.CONSOLE_WEB_URL}/activate?email={encoded_email}&token={token}"
return MemberInviteResponse(
email=normalized_email,
role=payload.role,
role=body.role,
member_id=str(member.id),
invite_url=invite_url,
tenant_id=str(tenant.id),
).model_dump(mode="json"), 201
)
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>")
@ -251,12 +236,12 @@ class WorkspaceMemberApi(Resource):
400 per the spec, with the service's message preserved.
"""
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
@returns(200, MemberActionResponse, description="Member removed")
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
operator = _load_account(auth_data.account_id)
tenant = _load_tenant(workspace_id)
@ -273,7 +258,7 @@ class WorkspaceMemberApi(Resource):
except MemberNotInTenantError as exc:
raise NotFound(str(exc))
return MemberActionResponse().model_dump(mode="json"), 200
return MemberActionResponse()
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>/role")
@ -284,15 +269,14 @@ class WorkspaceMemberRoleApi(Resource):
standing owner (service NoPermissionError 400, per spec).
"""
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
@auth_router.guard_workspace(
scope=Scope.WORKSPACE_WRITE,
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}),
)
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
payload = _validate_body(MemberRoleUpdatePayload)
@returns(200, MemberActionResponse, description="Role updated")
@accepts(body=MemberRoleUpdatePayload)
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData, body: MemberRoleUpdatePayload):
operator = _load_account(auth_data.account_id)
tenant = _load_tenant(workspace_id)
member = AccountService.get_account_by_id(db.session, member_id)
@ -300,7 +284,7 @@ class WorkspaceMemberRoleApi(Resource):
raise NotFound("member not found")
try:
TenantService.update_member_role(tenant, member, payload.role, operator)
TenantService.update_member_role(tenant, member, body.role, operator)
except CannotOperateSelfError as exc:
raise BadRequest(str(exc))
except NoPermissionError as exc:
@ -310,7 +294,7 @@ class WorkspaceMemberRoleApi(Resource):
except RoleAlreadyAssignedError as exc:
raise BadRequest(str(exc))
return MemberActionResponse().model_dump(mode="json"), 200
return MemberActionResponse()
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:

View File

@ -299,6 +299,15 @@ Upload a file to use as an input variable when running the app
### /permitted-external-apps
#### GET
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| limit | query | | No | integer |
| mode | query | | No | string |
| name | query | | No | string |
| page | query | | No | integer |
##### Responses
| Code | Description | Schema |

View File

@ -94,4 +94,4 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo
queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1")
graph_instance.send_stop_command.assert_called_once_with("task-1")
assert result == {"result": "success"}
assert result == ({"result": "success"}, 200)

View File

@ -0,0 +1,210 @@
"""Unit tests for the @accepts / @returns contract decorators.
Exercises the decorators in isolation (not through a real controller): a plain
view function decorated with @accepts/@returns, driven inside a request context.
"""
from functools import wraps
import pytest
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import UnprocessableEntity
from controllers.common.schema import register_response_schema_model, register_schema_model
from controllers.openapi import openapi_ns
from controllers.openapi._contract import accepts, returns
class ContractQuery(BaseModel):
model_config = ConfigDict(extra="forbid")
page: int = Field(1, ge=1)
limit: int = Field(20, ge=1, le=100)
class ContractBody(BaseModel):
model_config = ConfigDict(extra="forbid")
name: str
class ContractResp(BaseModel):
value: int
@pytest.fixture(autouse=True, scope="module")
def _register_contract_test_models():
# Register for @accepts(body=)/@returns name lookups; drop on teardown so these
# test-only models don't leak into the shared openapi_ns / generated spec.
register_schema_model(openapi_ns, ContractBody)
register_response_schema_model(openapi_ns, ContractResp)
yield
openapi_ns.models.pop(ContractBody.__name__, None)
openapi_ns.models.pop(ContractResp.__name__, None)
def _guard_like(view):
"""Stand-in for ``@auth_router.guard`` — an outermost @wraps layer."""
@wraps(view)
def wrapper(*args, **kwargs):
return view(*args, **kwargs)
return wrapper
def test_accepts_injects_validated_query(app):
@accepts(query=ContractQuery)
def view(*, query):
return query
with app.test_request_context("/?page=3&limit=5"):
result = view()
assert isinstance(result, ContractQuery)
assert result.page == 3
assert result.limit == 5
def test_accepts_query_uses_defaults_when_absent(app):
@accepts(query=ContractQuery)
def view(*, query):
return query
with app.test_request_context("/"):
result = view()
assert result.page == 1
assert result.limit == 20
@pytest.mark.parametrize("query_string", ["page=0", "limit=999", "page=abc", "unknown=1"])
def test_accepts_rejects_invalid_query_with_422(app, query_string):
@accepts(query=ContractQuery)
def view(*, query):
return query
with app.test_request_context(f"/?{query_string}"):
with pytest.raises(UnprocessableEntity):
view()
def test_accepts_validation_error_is_sanitized_and_structured(app):
"""422 body is structured and leaks neither the pydantic docs url nor the user input."""
@accepts(body=ContractBody)
def view(*, body):
return body
with app.test_request_context("/", method="POST", json={"secret": "leak-me"}):
with pytest.raises(UnprocessableEntity) as exc_info:
view()
data = exc_info.value.data
assert data["message"] == "Request validation failed"
assert isinstance(data["errors"], list)
assert data["errors"]
for err in data["errors"]:
assert {"type", "loc", "msg"} <= err.keys()
assert "url" not in err
assert "input" not in err
assert "leak-me" not in str(data)
def test_accepts_injects_validated_body(app):
@accepts(body=ContractBody)
def view(*, body):
return body
with app.test_request_context("/", method="POST", json={"name": "x"}):
result = view()
assert isinstance(result, ContractBody)
assert result.name == "x"
def test_accepts_rejects_invalid_body_with_422(app):
@accepts(body=ContractBody)
def view(*, body):
return body
with app.test_request_context("/", method="POST", json={"wrong": 1}):
with pytest.raises(UnprocessableEntity):
view()
def test_returns_serializes_model_with_decorator_status(app):
@returns(200, ContractResp)
def view():
return ContractResp(value=7)
with app.test_request_context("/"):
body, status = view()
assert status == 200
assert body == {"value": 7}
def test_returns_serializes_model_in_tuple_and_honors_status(app):
@returns(200, ContractResp)
def view():
return ContractResp(value=9), 201
with app.test_request_context("/"):
body, status = view()
assert status == 201
assert body == {"value": 9}
def test_returns_passes_through_non_model(app):
sentinel = object()
@returns(200, ContractResp)
def view():
return sentinel
with app.test_request_context("/"):
result = view()
assert result is sentinel
def test_returns_serializes_model_in_three_tuple_with_headers(app):
"""A (model, status, headers) tuple keeps its trailing status/headers intact."""
@returns(200, ContractResp)
def view():
return ContractResp(value=3), 202, {"X-Test": "1"}
with app.test_request_context("/"):
body, status, headers = view()
assert body == {"value": 3}
assert status == 202
assert headers == {"X-Test": "1"}
# Swagger metadata (read off __apidoc__) must survive @wraps up through the guard layer.
def test_accepts_returns_emit_apidoc_through_guard_stack():
@_guard_like
@returns(200, ContractResp)
@accepts(query=ContractQuery)
def view(*, query):
return ContractResp(value=1)
apidoc = getattr(view, "__apidoc__", {})
assert "page" in apidoc.get("params", {}) # from @accepts(query=)
assert "200" in apidoc.get("responses", {}) # from @returns (flask_restx keys by str code)
def test_accepts_body_emits_expect_through_guard_stack():
@_guard_like
@accepts(body=ContractBody)
def view(*, body):
return body
apidoc = getattr(view, "__apidoc__", {})
assert apidoc.get("expect") # body schema advertised via @openapi_ns.expect

View File

@ -11,7 +11,7 @@ from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import NotFound, UnprocessableEntity
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
@ -233,3 +233,24 @@ class TestOpenApiHumanInputFormPost:
submission_end_user_id="eu-7",
)
assert result == ({}, 200)
def test_post_rejects_invalid_body_with_422(self, app: Flask, bypass_pipeline):
"""Malformed body → 422 via @accepts (was an unmapped pydantic error → 500)."""
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-42")
with app.test_request_context(
"/openapi/v1/apps/app-1/form/human_input/tok-1",
method="POST",
json={"inputs": {"field1": "val"}}, # missing required "action"
):
with pytest.raises(UnprocessableEntity):
api.post.__wrapped__(
api,
app_id="app-1",
form_token="tok-1",
auth_data=_make_auth_data(app_model, caller, "account"),
)

View File

@ -29,7 +29,7 @@ import pytest
from flask import Flask
from flask.views import MethodView
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, UnprocessableEntity
from controllers.openapi import bp as openapi_bp
from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload
@ -198,7 +198,7 @@ def test_member_role_route_registered(openapi_app: Flask):
# ---------------------------------------------------------------------------
# Payload validation lands at 400
# Payload validation lands at 422 (unified via @accepts)
# ---------------------------------------------------------------------------
@ -227,18 +227,38 @@ def test_role_payload_rejects_extra_field():
MemberRoleUpdatePayload.model_validate({"role": "normal", "extra": "x"})
def test_validate_body_helper_maps_validation_error_to_400(app, monkeypatch):
"""`_validate_body` is the centralized 400-mapper for invalid request bodies."""
from controllers.openapi.workspaces import _validate_body
def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline):
"""Invalid invite body → 422 via @accepts (was 400 through _validate_body)."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMembersApi()
with app.test_request_context(
"/openapi/v1/workspaces/ws-1/members",
f"/openapi/v1/workspaces/{ws_id}/members",
method="POST",
data=json.dumps({"email": "u@example.com", "role": "owner"}),
data=json.dumps({"email": "u@example.com", "role": "owner"}), # owner is not invite-assignable
content_type="application/json",
):
with pytest.raises(BadRequest):
_validate_body(MemberInvitePayload)
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(UnprocessableEntity):
api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))
def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline):
"""Invalid role-update body surfaces as 422 through @accepts (was 400)."""
ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMemberRoleApi()
with app.test_request_context(
f"/openapi/v1/workspaces/{ws_id}/members/{member_id}/role",
method="PUT",
data=json.dumps({"role": "owner"}), # closed enum rejects owner
content_type="application/json",
):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(UnprocessableEntity):
api.put.__wrapped__(api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id))
# ---------------------------------------------------------------------------
@ -384,7 +404,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa
def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypatch):
"""Strict (`extra='forbid'`) — typos like `?pg=2` surface as 400."""
"""Strict (`extra='forbid'`) — typos like `?pg=2` surface as 422 (unified via @accepts)."""
ws_id = str(uuid.uuid4())
acct_id = uuid.uuid4()
api = WorkspaceMembersApi()
@ -395,7 +415,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa
with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"):
_seed(_auth_ctx(account_id=acct_id))
with pytest.raises(BadRequest):
with pytest.raises(UnprocessableEntity):
api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id))

View File

@ -24,6 +24,7 @@ import {
zGetHealthResponse,
zGetOauthDeviceLookupQuery,
zGetOauthDeviceLookupResponse,
zGetPermittedExternalAppsQuery,
zGetPermittedExternalAppsResponse,
zGetVersionResponse,
zGetWorkspacesByWorkspaceIdMembersPath,
@ -438,6 +439,7 @@ export const get10 = oc
path: '/permitted-external-apps',
tags: ['openapi'],
})
.input(z.object({ query: zGetPermittedExternalAppsQuery.optional() }))
.output(zGetPermittedExternalAppsResponse)
export const permittedExternalApps = {

View File

@ -656,7 +656,12 @@ export type PostOauthDeviceTokenResponse
export type GetPermittedExternalAppsData = {
body?: never
path?: never
query?: never
query?: {
limit?: number
mode?: string
name?: string
page?: number
}
url: '/permitted-external-apps'
}

View File

@ -638,6 +638,13 @@ export const zPostOauthDeviceTokenBody = zDevicePollRequest
*/
export const zPostOauthDeviceTokenResponse = z.record(z.string(), z.unknown())
export const zGetPermittedExternalAppsQuery = z.object({
limit: z.int().gte(1).lte(200).optional().default(20),
mode: z.string().optional(),
name: z.string().max(200).optional(),
page: z.int().gte(1).optional().default(1),
})
/**
* Permitted external apps list
*/