diff --git a/api/controllers/openapi/_contract.py b/api/controllers/openapi/_contract.py new file mode 100644 index 0000000000..0979b01a35 --- /dev/null +++ b/api/controllers/openapi/_contract.py @@ -0,0 +1,81 @@ +"""Request/response contract decorators for the openapi controllers. + +``@accepts`` and ``@returns`` own one slice of the contract from a single model +reference — emitting the Swagger schema AND doing the runtime validation/ +serialisation — so the advertised and enforced contracts can't drift. Validation +failures map to a single shape: 422. + +They must sit BELOW ``@auth_router.guard`` so auth runs before validation and the +``view.__wrapped__`` unit-test seam unwraps exactly the guard layer. +""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import Any + +from flask import request +from flask_restx import abort +from pydantic import BaseModel, ValidationError + +from controllers.common.schema import query_params_from_model, query_params_from_request +from controllers.openapi import openapi_ns + + +def accepts(*, query: type[BaseModel] | None = None, body: type[BaseModel] | None = None) -> Callable: + """Validate ``query``/``body`` against the models and inject them as keyword-only kwargs. + + Emits the matching Swagger schema from the same models, so doc and enforcement + stay in lockstep. + """ + + def decorator(view: Callable) -> Callable: + @wraps(view) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + if query is not None: + kwargs["query"] = query_params_from_request(query) + if body is not None: + kwargs["body"] = body.model_validate(request.get_json(silent=True) or {}) + except ValidationError as exc: + # Sanitized 422 — no pydantic `url` (version) or `input` (user payload) leak. + abort( + 422, + message="Request validation failed", + errors=exc.errors(include_url=False, include_input=False, include_context=False), + ) + return view(*args, **kwargs) + + if query is not None: + openapi_ns.doc(params=query_params_from_model(query))(wrapper) + if body is not None: + openapi_ns.expect(openapi_ns.models[body.__name__])(wrapper) + return wrapper + + return decorator + + +def returns(code: int, model: type[BaseModel], description: str | None = None) -> Callable: + """Serialise the handler's returned model and emit the response schema. + + Accepts a ``BaseModel`` (serialised with ``code``) or a ``(model, status[, headers])`` + tuple (status/headers honoured). Other returns — a bare ``(dict, status)``, an SSE + ``Response`` — pass through untouched. + """ + + def decorator(view: Callable) -> Callable: + @wraps(view) + def wrapper(*args: Any, **kwargs: Any) -> Any: + result = view(*args, **kwargs) + if isinstance(result, BaseModel): + return result.model_dump(mode="json"), code + if isinstance(result, tuple) and result and isinstance(result[0], BaseModel): + payload, *rest = result + return (payload.model_dump(mode="json"), *rest) + return result + + openapi_ns.response(code, description or model.__name__, openapi_ns.models[model.__name__])(wrapper) + return wrapper + + return decorator diff --git a/api/controllers/openapi/_meta.py b/api/controllers/openapi/_meta.py index e1c380bf55..c49f7526ac 100644 --- a/api/controllers/openapi/_meta.py +++ b/api/controllers/openapi/_meta.py @@ -9,15 +9,16 @@ from flask_restx import Resource from configs import dify_config from controllers.openapi import openapi_ns +from controllers.openapi._contract import returns from controllers.openapi._models import ServerVersionResponse @openapi_ns.route("/_version") class VersionApi(Resource): - @openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__]) + @returns(200, ServerVersionResponse, description="Server version") def get(self): edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED" return ServerVersionResponse( version=dify_config.project.version, edition=edition, - ).model_dump(mode="json") + ) diff --git a/api/controllers/openapi/account.py b/api/controllers/openapi/account.py index 05223a97e6..8ad0b02f4a 100644 --- a/api/controllers/openapi/account.py +++ b/api/controllers/openapi/account.py @@ -2,17 +2,14 @@ from __future__ import annotations from datetime import UTC, datetime -from flask import request from flask_restx import Resource -from pydantic import ValidationError -from werkzeug.exceptions import NotFound, UnprocessableEntity +from werkzeug.exceptions import NotFound -from controllers.common.schema import query_params_from_model from controllers.openapi import openapi_ns +from controllers.openapi._contract import accepts, returns from controllers.openapi._models import ( AccountPayload, AccountResponse, - PaginationEnvelope, RevokeResponse, SessionListQuery, SessionListResponse, @@ -42,8 +39,8 @@ from services.oauth_device_flow import ( @openapi_ns.route("/account") class AccountApi(Resource): - @openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__]) @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @returns(200, AccountResponse, description="Account info") def get(self, *, auth_data: AuthData): enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}") @@ -58,31 +55,27 @@ class AccountApi(Resource): account=_account_payload(account) if account else None, workspaces=[_workspace_payload(m) for m in memberships], default_workspace_id=default_ws_id, - ).model_dump(mode="json") + ) @openapi_ns.route("/account/sessions/self") class AccountSessionsSelfApi(Resource): - @openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__]) @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @returns(200, RevokeResponse, description="Session revoked") def delete(self, *, auth_data: AuthData): revoke_oauth_token(db.session, redis_client, str(auth_data.token_id)) - return RevokeResponse(status="revoked").model_dump(mode="json"), 200 + return RevokeResponse(status="revoked") @openapi_ns.route("/account/sessions") class AccountSessionsApi(Resource): - @openapi_ns.doc(params=query_params_from_model(SessionListQuery)) - @openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__]) @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) - def get(self, *, auth_data: AuthData): - # Validate page/limit through the same model the contract advertises (extra='forbid', - # page>=1, 1<=limit<=MAX_PAGE_LIMIT) so the server actually enforces those bounds rather - # than silently coercing (e.g. page=0 -> empty slice). Mirrors AppDescribeQuery. - try: - query = SessionListQuery.model_validate(request.args.to_dict(flat=True)) - except ValidationError as exc: - raise UnprocessableEntity(exc.json()) + @returns(200, SessionListResponse, description="Session list") + @accepts(query=SessionListQuery) + def get(self, *, auth_data: AuthData, query: SessionListQuery): + # SessionListQuery enforces the advertised bounds (extra='forbid', page>=1, + # 1<=limit<=MAX_PAGE_LIMIT) so the server rejects out-of-range paging rather + # than silently coercing (e.g. page=0 -> empty slice). ctx = get_auth_ctx() now = datetime.now(UTC) page = query.page @@ -106,16 +99,19 @@ class AccountSessionsApi(Resource): for r in sliced ] - return ( - PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"), - 200, + return SessionListResponse( + page=page, + limit=limit, + total=total, + has_more=page * limit < total, + data=items, ) @openapi_ns.route("/account/sessions/") class AccountSessionByIdApi(Resource): - @openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__]) @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @returns(200, RevokeResponse, description="Session revoked") def delete(self, session_id: str, *, auth_data: AuthData): ctx = get_auth_ctx() @@ -125,7 +121,7 @@ class AccountSessionByIdApi(Resource): raise NotFound("session not found") revoke_oauth_token(db.session, redis_client, session_id) - return RevokeResponse(status="revoked").model_dump(mode="json"), 200 + return RevokeResponse(status="revoked") def _iso(dt: datetime | None) -> str | None: diff --git a/api/controllers/openapi/app_run.py b/api/controllers/openapi/app_run.py index 7b9030362a..d801f5183f 100644 --- a/api/controllers/openapi/app_run.py +++ b/api/controllers/openapi/app_run.py @@ -7,14 +7,13 @@ from collections.abc import Callable, Iterator from contextlib import contextmanager from typing import Any -from flask import request from flask_restx import Resource -from pydantic import ValidationError from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity import services from controllers.openapi import openapi_ns from controllers.openapi._audit import emit_app_run +from controllers.openapi._contract import accepts, returns from controllers.openapi._models import AppRunRequest, TaskStopResponse from controllers.openapi.auth.composition import auth_router from controllers.openapi.auth.data import AuthData @@ -123,23 +122,18 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = { @openapi_ns.route("/apps//run") class AppRunApi(Resource): - @openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__]) - @openapi_ns.response(200, "Run result (SSE stream)") @auth_router.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, *, auth_data: AuthData): + @openapi_ns.response(200, "Run result (SSE stream)") + @accepts(body=AppRunRequest) + def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest): app_model, caller, caller_kind = auth_data.require_app_context() - body = request.get_json(silent=True) or {} - try: - payload = AppRunRequest.model_validate(body) - except ValidationError as exc: - raise UnprocessableEntity(exc.json()) handler = _DISPATCH.get(app_model.mode) if handler is None: raise UnprocessableEntity("mode_not_runnable") try: - stream_obj = handler(app_model, caller, payload) + stream_obj = handler(app_model, caller, body) except HTTPException: raise except Exception: @@ -159,10 +153,10 @@ class AppRunApi(Resource): @openapi_ns.route("/apps//tasks//stop") class AppRunTaskStopApi(Resource): - @openapi_ns.response(200, "Task stopped", openapi_ns.models[TaskStopResponse.__name__]) @auth_router.guard(scope=Scope.APPS_RUN) + @returns(200, TaskStopResponse, description="Task stopped") def post(self, app_id: str, task_id: str, *, auth_data: AuthData): app_model, caller, caller_kind = auth_data.require_app_context() AppQueueManager.set_stop_flag_no_user_check(task_id) GraphEngineManager(redis_client).send_stop_command(task_id) - return {"result": "success"} + return TaskStopResponse(result="success") diff --git a/api/controllers/openapi/apps.py b/api/controllers/openapi/apps.py index 9520d6b097..84b1610d5f 100644 --- a/api/controllers/openapi/apps.py +++ b/api/controllers/openapi/apps.py @@ -5,14 +5,12 @@ from __future__ import annotations import uuid as _uuid from typing import Any, cast -from flask import request from flask_restx import Resource -from pydantic import ValidationError from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity from controllers.common.fields import Parameters -from controllers.common.schema import query_params_from_model from controllers.openapi import openapi_ns +from controllers.openapi._contract import accepts, returns from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config from controllers.openapi._models import ( AppDescribeInfo, @@ -88,15 +86,11 @@ def parameters_payload(app: App) -> dict: @openapi_ns.route("/apps//describe") class AppDescribeApi(AppReadResource): - @openapi_ns.doc(params=query_params_from_model(AppDescribeQuery)) - @openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__]) @auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) - def get(self, app_id: str, *, auth_data: AuthData): - try: - query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True)) - except ValidationError as exc: - raise UnprocessableEntity(exc.json()) - + @returns(200, AppDescribeResponse, description="App description") + @accepts(query=AppDescribeQuery) + def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery): + # describe is UUID-only (workspace_id query param dropped in #37212). app = self._load(app_id) requested = query.fields @@ -133,35 +127,22 @@ class AppDescribeApi(AppReadResource): except AppUnavailableError: input_schema = dict(EMPTY_INPUT_SCHEMA) - return ( - AppDescribeResponse( - info=info, - parameters=parameters, - input_schema=input_schema, - ).model_dump(mode="json", exclude_none=False), - 200, + return AppDescribeResponse( + info=info, + parameters=parameters, + input_schema=input_schema, ) @openapi_ns.route("/apps") class AppListApi(Resource): - @openapi_ns.doc(params=query_params_from_model(AppListQuery)) - @openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__]) @auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) - def get(self, *, auth_data: AuthData): - try: - query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True)) - except ValidationError as exc: - raise UnprocessableEntity(exc.json()) - + @returns(200, AppListResponse, description="App list") + @accepts(query=AppListQuery) + def get(self, *, auth_data: AuthData, query: AppListQuery): workspace_id = query.workspace_id - empty = ( - AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump( - mode="json" - ), - 200, - ) + empty = AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]) if query.name: try: @@ -189,7 +170,7 @@ class AppListApi(Resource): workspace_name=tenant_name, ) env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item]) - return env.model_dump(mode="json"), 200 + return env tag_ids: list[str] | None = None if query.tag: @@ -240,4 +221,4 @@ class AppListApi(Resource): has_more=query.page * query.limit < cast(int, pagination.total), data=items, ) - return env.model_dump(mode="json"), 200 + return env diff --git a/api/controllers/openapi/apps_permitted_external.py b/api/controllers/openapi/apps_permitted_external.py index f86fd34a19..0e889a2951 100644 --- a/api/controllers/openapi/apps_permitted_external.py +++ b/api/controllers/openapi/apps_permitted_external.py @@ -7,12 +7,10 @@ EE blueprint chain so this module is unreachable there. from __future__ import annotations -from flask import request from flask_restx import Resource -from pydantic import ValidationError -from werkzeug.exceptions import UnprocessableEntity from controllers.openapi import openapi_ns +from controllers.openapi._contract import accepts, returns from controllers.openapi._models import ( AppListRow, PermittedExternalAppsListQuery, @@ -30,20 +28,14 @@ from services.enterprise.app_permitted_service import list_permitted_apps @openapi_ns.route("/permitted-external-apps") class PermittedExternalAppsListApi(Resource): - @openapi_ns.response( - 200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__] - ) @auth_router.guard( scope=Scope.APPS_READ_PERMITTED_EXTERNAL, allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}), edition=frozenset({Edition.EE}), ) - def get(self, *, auth_data: AuthData): - try: - query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True)) - except ValidationError as exc: - raise UnprocessableEntity(exc.json()) - + @returns(200, PermittedExternalAppsListResponse, description="Permitted external apps list") + @accepts(query=PermittedExternalAppsListQuery) + def get(self, *, auth_data: AuthData, query: PermittedExternalAppsListQuery): page_result = list_permitted_apps( page=query.page, limit=query.limit, @@ -55,7 +47,7 @@ class PermittedExternalAppsListApi(Resource): env = PermittedExternalAppsListResponse( page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[] ) - return env.model_dump(mode="json"), 200 + return env apps_by_id: dict[str, App] = { str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids) @@ -89,4 +81,4 @@ class PermittedExternalAppsListApi(Resource): has_more=query.page * query.limit < page_result.total, data=items, ) - return env.model_dump(mode="json"), 200 + return env diff --git a/api/controllers/openapi/files.py b/api/controllers/openapi/files.py index 1a2c16abf9..e77e4bc302 100644 --- a/api/controllers/openapi/files.py +++ b/api/controllers/openapi/files.py @@ -17,6 +17,7 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from controllers.openapi import openapi_ns +from controllers.openapi._contract import returns from controllers.openapi.auth.composition import auth_router from controllers.openapi.auth.data import AuthData from extensions.ext_database import db @@ -38,8 +39,8 @@ class AppFileUploadApi(Resource): 415: "Unsupported file type or blocked extension", } ) - @openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__]) @auth_router.guard(scope=Scope.APPS_RUN) + @returns(HTTPStatus.CREATED, FileResponse, description="File uploaded") def post(self, app_id: str, *, auth_data: AuthData): app_model, caller, _ = auth_data.require_app_context() if "file" not in request.files: @@ -69,5 +70,4 @@ class AppFileUploadApi(Resource): except services.errors.file.BlockedFileExtensionError as exc: raise BlockedFileExtensionError(exc.description) - response = FileResponse.model_validate(upload_file, from_attributes=True) - return response.model_dump(mode="json"), 201 + return FileResponse.model_validate(upload_file, from_attributes=True) diff --git a/api/controllers/openapi/human_input_form.py b/api/controllers/openapi/human_input_form.py index 6b9d4a711e..e04dc8a1af 100644 --- a/api/controllers/openapi/human_input_form.py +++ b/api/controllers/openapi/human_input_form.py @@ -10,13 +10,14 @@ from __future__ import annotations import json import logging -from flask import Response, request +from flask import Response from flask_restx import Resource from werkzeug.exceptions import BadRequest, NotFound from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values from controllers.common.schema import register_schema_models from controllers.openapi import openapi_ns +from controllers.openapi._contract import accepts, returns from controllers.openapi._models import FormSubmitResponse from controllers.openapi.auth.composition import auth_router from controllers.openapi.auth.data import AuthData @@ -70,12 +71,11 @@ class OpenApiWorkflowHumanInputFormApi(Resource): service.ensure_form_active(form) return _jsonify_form_definition(form) - @openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__]) - @openapi_ns.response(200, "Form submitted", openapi_ns.models[FormSubmitResponse.__name__]) @auth_router.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, form_token: str, *, auth_data: AuthData): + @returns(200, FormSubmitResponse, description="Form submitted") + @accepts(body=HumanInputFormSubmitPayload) + def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload): app_model, caller, caller_kind = auth_data.require_app_context() - payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {}) service = HumanInputService(db.engine) form = service.get_form_by_token(form_token) @@ -100,12 +100,12 @@ class OpenApiWorkflowHumanInputFormApi(Resource): service.submit_form_by_token( recipient_type=form.recipient_type, form_token=form_token, - selected_action_id=payload.action, - form_data=payload.inputs, + selected_action_id=body.action, + form_data=body.inputs, submission_user_id=submission_user_id, submission_end_user_id=submission_end_user_id, ) except FormNotFoundError: raise NotFound("Form not found") - return {}, 200 + return FormSubmitResponse() diff --git a/api/controllers/openapi/index.py b/api/controllers/openapi/index.py index ae1780aecd..97e9c6e75d 100644 --- a/api/controllers/openapi/index.py +++ b/api/controllers/openapi/index.py @@ -1,11 +1,12 @@ from flask_restx import Resource from controllers.openapi import openapi_ns +from controllers.openapi._contract import returns from controllers.openapi._models import HealthResponse @openapi_ns.route("/_health") class HealthApi(Resource): - @openapi_ns.response(200, "Health check", openapi_ns.models[HealthResponse.__name__]) + @returns(200, HealthResponse, description="Health check") def get(self): - return {"ok": True} + return HealthResponse(ok=True) diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 7d01db8dc1..902337703a 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -14,14 +14,13 @@ from __future__ import annotations from itertools import starmap from urllib import parse -from flask import jsonify, make_response, request +from flask import jsonify, make_response from flask_restx import Resource -from pydantic import BaseModel, ValidationError from werkzeug.exceptions import BadRequest, Forbidden, NotFound from configs import dify_config -from controllers.common.schema import query_params_from_model from controllers.openapi import openapi_ns +from controllers.openapi._contract import accepts, returns from controllers.openapi._models import ( MemberActionResponse, MemberInvitePayload, @@ -53,14 +52,6 @@ from services.errors.account import ( from services.feature_service import FeatureService -def _validate_body[M: BaseModel](model: type[M]) -> M: - body = request.get_json(silent=True) or {} - try: - return model.model_validate(body) - except ValidationError as exc: - raise BadRequest(str(exc)) - - def _member_response(account: Account) -> MemberResponse: return MemberResponse( id=str(account.id), @@ -118,18 +109,18 @@ def _check_member_invite_quota(tenant_id: str) -> None: @openapi_ns.route("/workspaces") class WorkspacesApi(Resource): - @openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__]) @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @returns(200, WorkspaceListResponse, description="Workspace list") def get(self, *, auth_data: AuthData): rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id)) - return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200 + return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))) @openapi_ns.route("/workspaces/") class WorkspaceByIdApi(Resource): - @openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__]) @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @returns(200, WorkspaceDetailResponse, description="Workspace detail") def get(self, workspace_id: str, *, auth_data: AuthData): row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id) # 404 (not 403) on non-member so workspace IDs don't leak across tenants. @@ -137,7 +128,7 @@ class WorkspaceByIdApi(Resource): raise NotFound("workspace not found") tenant, membership = row - return _workspace_detail(tenant, membership).model_dump(mode="json"), 200 + return _workspace_detail(tenant, membership) @openapi_ns.route("/workspaces//switch") @@ -149,8 +140,8 @@ class WorkspaceSwitchApi(Resource): that ``hosts.yml`` never diverges from the server's ``current`` state. """ - @openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__]) @auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + @returns(200, WorkspaceDetailResponse, description="Workspace detail") def post(self, workspace_id: str, *, auth_data: AuthData): account = _load_account(auth_data.account_id) @@ -163,7 +154,7 @@ class WorkspaceSwitchApi(Resource): if row is None: raise NotFound("workspace not found") tenant, membership = row - return _workspace_detail(tenant, membership).model_dump(mode="json"), 200 + return _workspace_detail(tenant, membership) @openapi_ns.route("/workspaces//members") @@ -174,15 +165,10 @@ class WorkspaceMembersApi(Resource): assigned through invite (ownership transfer is console-only). """ - @openapi_ns.doc(params=query_params_from_model(MemberListQuery)) - @openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__]) @auth_router.guard_workspace(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) - def get(self, workspace_id: str, *, auth_data: AuthData): - try: - query = MemberListQuery.model_validate(request.args.to_dict(flat=True)) - except ValidationError as exc: - raise BadRequest(str(exc)) - + @returns(200, MemberListResponse, description="Member list") + @accepts(query=MemberListQuery) + def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery): tenant = _load_tenant(workspace_id) members = TenantService.get_tenant_members(tenant) total = len(members) @@ -194,17 +180,16 @@ class WorkspaceMembersApi(Resource): total=total, has_more=query.page * query.limit < total, data=[_member_response(m) for m in page_items], - ).model_dump(mode="json"), 200 + ) - @openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__]) - @openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__]) @auth_router.guard_workspace( scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}), allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}), ) - def post(self, workspace_id: str, *, auth_data: AuthData): - payload = _validate_body(MemberInvitePayload) + @returns(201, MemberInviteResponse, description="Member invited") + @accepts(body=MemberInvitePayload) + def post(self, workspace_id: str, *, auth_data: AuthData, body: MemberInvitePayload): inviter = _load_account(auth_data.account_id) tenant = _load_tenant(workspace_id) @@ -213,9 +198,9 @@ class WorkspaceMembersApi(Resource): try: token = RegisterService.invite_new_member( tenant=tenant, - email=payload.email, + email=body.email, language=None, - role=payload.role, + role=body.role, inviter=inviter, ) except AccountAlreadyInTenantError as exc: @@ -225,7 +210,7 @@ class WorkspaceMembersApi(Resource): except AccountRegisterError as exc: raise BadRequest(str(exc)) - normalized_email = payload.email.lower() + normalized_email = body.email.lower() member = AccountService.get_account_by_email_with_case_fallback(normalized_email) if member is None: # invite_new_member just created or fetched this account. @@ -235,11 +220,11 @@ class WorkspaceMembersApi(Resource): invite_url = f"{dify_config.CONSOLE_WEB_URL}/activate?email={encoded_email}&token={token}" return MemberInviteResponse( email=normalized_email, - role=payload.role, + role=body.role, member_id=str(member.id), invite_url=invite_url, tenant_id=str(tenant.id), - ).model_dump(mode="json"), 201 + ) @openapi_ns.route("/workspaces//members/") @@ -251,12 +236,12 @@ class WorkspaceMemberApi(Resource): 400 per the spec, with the service's message preserved. """ - @openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__]) @auth_router.guard_workspace( scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}), allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}), ) + @returns(200, MemberActionResponse, description="Member removed") def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData): operator = _load_account(auth_data.account_id) tenant = _load_tenant(workspace_id) @@ -273,7 +258,7 @@ class WorkspaceMemberApi(Resource): except MemberNotInTenantError as exc: raise NotFound(str(exc)) - return MemberActionResponse().model_dump(mode="json"), 200 + return MemberActionResponse() @openapi_ns.route("/workspaces//members//role") @@ -284,15 +269,14 @@ class WorkspaceMemberRoleApi(Resource): standing owner (service NoPermissionError → 400, per spec). """ - @openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__]) - @openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__]) @auth_router.guard_workspace( scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}), allowed_roles=frozenset({TenantAccountRole.OWNER, TenantAccountRole.ADMIN}), ) - def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData): - payload = _validate_body(MemberRoleUpdatePayload) + @returns(200, MemberActionResponse, description="Role updated") + @accepts(body=MemberRoleUpdatePayload) + def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData, body: MemberRoleUpdatePayload): operator = _load_account(auth_data.account_id) tenant = _load_tenant(workspace_id) member = AccountService.get_account_by_id(db.session, member_id) @@ -300,7 +284,7 @@ class WorkspaceMemberRoleApi(Resource): raise NotFound("member not found") try: - TenantService.update_member_role(tenant, member, payload.role, operator) + TenantService.update_member_role(tenant, member, body.role, operator) except CannotOperateSelfError as exc: raise BadRequest(str(exc)) except NoPermissionError as exc: @@ -310,7 +294,7 @@ class WorkspaceMemberRoleApi(Resource): except RoleAlreadyAssignedError as exc: raise BadRequest(str(exc)) - return MemberActionResponse().model_dump(mode="json"), 200 + return MemberActionResponse() def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse: diff --git a/api/openapi/markdown/openapi-swagger.md b/api/openapi/markdown/openapi-swagger.md index bbeb9f7f2a..8c19adb203 100644 --- a/api/openapi/markdown/openapi-swagger.md +++ b/api/openapi/markdown/openapi-swagger.md @@ -299,6 +299,15 @@ Upload a file to use as an input variable when running the app ### /permitted-external-apps #### GET +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| limit | query | | No | integer | +| mode | query | | No | string | +| name | query | | No | string | +| page | query | | No | integer | + ##### Responses | Code | Description | Schema | diff --git a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py index 8933533af0..0dbb595ba1 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py +++ b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py @@ -94,4 +94,4 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1") graph_instance.send_stop_command.assert_called_once_with("task-1") - assert result == {"result": "success"} + assert result == ({"result": "success"}, 200) diff --git a/api/tests/unit_tests/controllers/openapi/test_contract.py b/api/tests/unit_tests/controllers/openapi/test_contract.py new file mode 100644 index 0000000000..990437e37f --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_contract.py @@ -0,0 +1,210 @@ +"""Unit tests for the @accepts / @returns contract decorators. + +Exercises the decorators in isolation (not through a real controller): a plain +view function decorated with @accepts/@returns, driven inside a request context. +""" + +from functools import wraps + +import pytest +from pydantic import BaseModel, ConfigDict, Field +from werkzeug.exceptions import UnprocessableEntity + +from controllers.common.schema import register_response_schema_model, register_schema_model +from controllers.openapi import openapi_ns +from controllers.openapi._contract import accepts, returns + + +class ContractQuery(BaseModel): + model_config = ConfigDict(extra="forbid") + + page: int = Field(1, ge=1) + limit: int = Field(20, ge=1, le=100) + + +class ContractBody(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str + + +class ContractResp(BaseModel): + value: int + + +@pytest.fixture(autouse=True, scope="module") +def _register_contract_test_models(): + # Register for @accepts(body=)/@returns name lookups; drop on teardown so these + # test-only models don't leak into the shared openapi_ns / generated spec. + register_schema_model(openapi_ns, ContractBody) + register_response_schema_model(openapi_ns, ContractResp) + yield + openapi_ns.models.pop(ContractBody.__name__, None) + openapi_ns.models.pop(ContractResp.__name__, None) + + +def _guard_like(view): + """Stand-in for ``@auth_router.guard`` — an outermost @wraps layer.""" + + @wraps(view) + def wrapper(*args, **kwargs): + return view(*args, **kwargs) + + return wrapper + + +def test_accepts_injects_validated_query(app): + @accepts(query=ContractQuery) + def view(*, query): + return query + + with app.test_request_context("/?page=3&limit=5"): + result = view() + + assert isinstance(result, ContractQuery) + assert result.page == 3 + assert result.limit == 5 + + +def test_accepts_query_uses_defaults_when_absent(app): + @accepts(query=ContractQuery) + def view(*, query): + return query + + with app.test_request_context("/"): + result = view() + + assert result.page == 1 + assert result.limit == 20 + + +@pytest.mark.parametrize("query_string", ["page=0", "limit=999", "page=abc", "unknown=1"]) +def test_accepts_rejects_invalid_query_with_422(app, query_string): + @accepts(query=ContractQuery) + def view(*, query): + return query + + with app.test_request_context(f"/?{query_string}"): + with pytest.raises(UnprocessableEntity): + view() + + +def test_accepts_validation_error_is_sanitized_and_structured(app): + """422 body is structured and leaks neither the pydantic docs url nor the user input.""" + + @accepts(body=ContractBody) + def view(*, body): + return body + + with app.test_request_context("/", method="POST", json={"secret": "leak-me"}): + with pytest.raises(UnprocessableEntity) as exc_info: + view() + + data = exc_info.value.data + assert data["message"] == "Request validation failed" + assert isinstance(data["errors"], list) + assert data["errors"] + for err in data["errors"]: + assert {"type", "loc", "msg"} <= err.keys() + assert "url" not in err + assert "input" not in err + assert "leak-me" not in str(data) + + +def test_accepts_injects_validated_body(app): + @accepts(body=ContractBody) + def view(*, body): + return body + + with app.test_request_context("/", method="POST", json={"name": "x"}): + result = view() + + assert isinstance(result, ContractBody) + assert result.name == "x" + + +def test_accepts_rejects_invalid_body_with_422(app): + @accepts(body=ContractBody) + def view(*, body): + return body + + with app.test_request_context("/", method="POST", json={"wrong": 1}): + with pytest.raises(UnprocessableEntity): + view() + + +def test_returns_serializes_model_with_decorator_status(app): + @returns(200, ContractResp) + def view(): + return ContractResp(value=7) + + with app.test_request_context("/"): + body, status = view() + + assert status == 200 + assert body == {"value": 7} + + +def test_returns_serializes_model_in_tuple_and_honors_status(app): + @returns(200, ContractResp) + def view(): + return ContractResp(value=9), 201 + + with app.test_request_context("/"): + body, status = view() + + assert status == 201 + assert body == {"value": 9} + + +def test_returns_passes_through_non_model(app): + sentinel = object() + + @returns(200, ContractResp) + def view(): + return sentinel + + with app.test_request_context("/"): + result = view() + + assert result is sentinel + + +def test_returns_serializes_model_in_three_tuple_with_headers(app): + """A (model, status, headers) tuple keeps its trailing status/headers intact.""" + + @returns(200, ContractResp) + def view(): + return ContractResp(value=3), 202, {"X-Test": "1"} + + with app.test_request_context("/"): + body, status, headers = view() + + assert body == {"value": 3} + assert status == 202 + assert headers == {"X-Test": "1"} + + +# Swagger metadata (read off __apidoc__) must survive @wraps up through the guard layer. + + +def test_accepts_returns_emit_apidoc_through_guard_stack(): + @_guard_like + @returns(200, ContractResp) + @accepts(query=ContractQuery) + def view(*, query): + return ContractResp(value=1) + + apidoc = getattr(view, "__apidoc__", {}) + assert "page" in apidoc.get("params", {}) # from @accepts(query=) + assert "200" in apidoc.get("responses", {}) # from @returns (flask_restx keys by str code) + + +def test_accepts_body_emits_expect_through_guard_stack(): + @_guard_like + @accepts(body=ContractBody) + def view(*, body): + return body + + apidoc = getattr(view, "__apidoc__", {}) + assert apidoc.get("expect") # body schema advertised via @openapi_ns.expect diff --git a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py index da4289bdde..f8d296deb3 100644 --- a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py @@ -11,7 +11,7 @@ from unittest.mock import Mock import pytest from flask import Flask -from werkzeug.exceptions import NotFound +from werkzeug.exceptions import NotFound, UnprocessableEntity from controllers.openapi.auth.data import AuthData from libs.oauth_bearer import Scope, TokenType @@ -233,3 +233,24 @@ class TestOpenApiHumanInputFormPost: submission_end_user_id="eu-7", ) assert result == ({}, 200) + + def test_post_rejects_invalid_body_with_422(self, app: Flask, bypass_pipeline): + """Malformed body → 422 via @accepts (was an unmapped pydantic error → 500).""" + from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi + + api = OpenApiWorkflowHumanInputFormApi() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-42") + + with app.test_request_context( + "/openapi/v1/apps/app-1/form/human_input/tok-1", + method="POST", + json={"inputs": {"field1": "val"}}, # missing required "action" + ): + with pytest.raises(UnprocessableEntity): + api.post.__wrapped__( + api, + app_id="app-1", + form_token="tok-1", + auth_data=_make_auth_data(app_model, caller, "account"), + ) diff --git a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py index 548b58286e..6bb13ad322 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py +++ b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py @@ -29,7 +29,7 @@ import pytest from flask import Flask from flask.views import MethodView from pydantic import ValidationError -from werkzeug.exceptions import BadRequest, Forbidden, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, NotFound, UnprocessableEntity from controllers.openapi import bp as openapi_bp from controllers.openapi._models import MemberInvitePayload, MemberRoleUpdatePayload @@ -198,7 +198,7 @@ def test_member_role_route_registered(openapi_app: Flask): # --------------------------------------------------------------------------- -# Payload validation lands at 400 +# Payload validation lands at 422 (unified via @accepts) # --------------------------------------------------------------------------- @@ -227,18 +227,38 @@ def test_role_payload_rejects_extra_field(): MemberRoleUpdatePayload.model_validate({"role": "normal", "extra": "x"}) -def test_validate_body_helper_maps_validation_error_to_400(app, monkeypatch): - """`_validate_body` is the centralized 400-mapper for invalid request bodies.""" - from controllers.openapi.workspaces import _validate_body +def test_invite_rejects_invalid_body_with_422(app, bypass_pipeline): + """Invalid invite body → 422 via @accepts (was 400 through _validate_body).""" + ws_id = str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMembersApi() with app.test_request_context( - "/openapi/v1/workspaces/ws-1/members", + f"/openapi/v1/workspaces/{ws_id}/members", method="POST", - data=json.dumps({"email": "u@example.com", "role": "owner"}), + data=json.dumps({"email": "u@example.com", "role": "owner"}), # owner is not invite-assignable content_type="application/json", ): - with pytest.raises(BadRequest): - _validate_body(MemberInvitePayload) + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(UnprocessableEntity): + api.post.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) + + +def test_update_role_rejects_invalid_body_with_422(app, bypass_pipeline): + """Invalid role-update body surfaces as 422 through @accepts (was 400).""" + ws_id, member_id = str(uuid.uuid4()), str(uuid.uuid4()) + acct_id = uuid.uuid4() + api = WorkspaceMemberRoleApi() + + with app.test_request_context( + f"/openapi/v1/workspaces/{ws_id}/members/{member_id}/role", + method="PUT", + data=json.dumps({"role": "owner"}), # closed enum rejects owner + content_type="application/json", + ): + _seed(_auth_ctx(account_id=acct_id)) + with pytest.raises(UnprocessableEntity): + api.put.__wrapped__(api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id)) # --------------------------------------------------------------------------- @@ -384,7 +404,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypatch): - """Strict (`extra='forbid'`) — typos like `?pg=2` surface as 400.""" + """Strict (`extra='forbid'`) — typos like `?pg=2` surface as 422 (unified via @accepts).""" ws_id = str(uuid.uuid4()) acct_id = uuid.uuid4() api = WorkspaceMembersApi() @@ -395,7 +415,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"): _seed(_auth_ctx(account_id=acct_id)) - with pytest.raises(BadRequest): + with pytest.raises(UnprocessableEntity): api.get.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) diff --git a/packages/contracts/generated/api/openapi/orpc.gen.ts b/packages/contracts/generated/api/openapi/orpc.gen.ts index 4fa3c53614..f909cdffe8 100644 --- a/packages/contracts/generated/api/openapi/orpc.gen.ts +++ b/packages/contracts/generated/api/openapi/orpc.gen.ts @@ -24,6 +24,7 @@ import { zGetHealthResponse, zGetOauthDeviceLookupQuery, zGetOauthDeviceLookupResponse, + zGetPermittedExternalAppsQuery, zGetPermittedExternalAppsResponse, zGetVersionResponse, zGetWorkspacesByWorkspaceIdMembersPath, @@ -438,6 +439,7 @@ export const get10 = oc path: '/permitted-external-apps', tags: ['openapi'], }) + .input(z.object({ query: zGetPermittedExternalAppsQuery.optional() })) .output(zGetPermittedExternalAppsResponse) export const permittedExternalApps = { diff --git a/packages/contracts/generated/api/openapi/types.gen.ts b/packages/contracts/generated/api/openapi/types.gen.ts index ccf98cdf52..7fcd742db2 100644 --- a/packages/contracts/generated/api/openapi/types.gen.ts +++ b/packages/contracts/generated/api/openapi/types.gen.ts @@ -656,7 +656,12 @@ export type PostOauthDeviceTokenResponse export type GetPermittedExternalAppsData = { body?: never path?: never - query?: never + query?: { + limit?: number + mode?: string + name?: string + page?: number + } url: '/permitted-external-apps' } diff --git a/packages/contracts/generated/api/openapi/zod.gen.ts b/packages/contracts/generated/api/openapi/zod.gen.ts index f34d6f4bb8..e143752736 100644 --- a/packages/contracts/generated/api/openapi/zod.gen.ts +++ b/packages/contracts/generated/api/openapi/zod.gen.ts @@ -638,6 +638,13 @@ export const zPostOauthDeviceTokenBody = zDevicePollRequest */ export const zPostOauthDeviceTokenResponse = z.record(z.string(), z.unknown()) +export const zGetPermittedExternalAppsQuery = z.object({ + limit: z.int().gte(1).lte(200).optional().default(20), + mode: z.string().optional(), + name: z.string().max(200).optional(), + page: z.int().gte(1).optional().default(1), +}) + /** * Permitted external apps list */