diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 79b3e6cc9f..2ce6bc3e6d 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -12,8 +12,9 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from controllers.common.human_input import HumanInputFormSubmitPayload +from controllers.common.schema import register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, model_validate, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator @@ -33,6 +34,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream logger = logging.getLogger(__name__) +register_schema_models(console_ns, HumanInputFormSubmitPayload) + def _jsonify_form_definition(form: Form) -> Response: payload = form.get_definition().model_dump() @@ -76,7 +79,9 @@ class ConsoleHumanInputFormApi(Resource): @account_initialization_required @login_required - def post(self, form_token: str): + @model_validate(HumanInputFormSubmitPayload) + @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__]) + def post(self, payload: HumanInputFormSubmitPayload, form_token: str): """ Submit human input form by form token. @@ -90,7 +95,6 @@ class ConsoleHumanInputFormApi(Resource): "action": "Approve" } """ - payload = HumanInputFormSubmitPayload.model_validate(request.get_json()) current_user, _ = current_account_with_tenant() service = HumanInputService(db.engine) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 603645278f..0a7ba552ee 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -7,7 +7,9 @@ from functools import wraps from typing import Concatenate from flask import abort, request +from pydantic import BaseModel, ValidationError from sqlalchemy import select +from werkzeug.exceptions import UnprocessableEntity from configs import dify_config from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError @@ -518,3 +520,38 @@ def with_current_user[T, **P, R]( return view(self, current_user, *args, **kwargs) return decorated + + +def model_validate[T, M: BaseModel, **P, R]( + model: type[M], +) -> Callable[ + [Callable[Concatenate[T, M, P], R]], + Callable[Concatenate[T, P], R], +]: + """Validate request data and inject the model instance as the first arg after self. + + Source is determined by HTTP method: + GET/DELETE -> request.args + POST/PUT/PATCH -> JSON body + """ + + def decorator( + view: Callable[Concatenate[T, M, P], R], + ) -> Callable[Concatenate[T, P], R]: + @wraps(view) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R: + if request.method in ("GET", "DELETE"): + raw = request.args.to_dict(flat=True) + else: + raw = request.get_json(silent=True) or {} + + try: + validated = model.model_validate(raw) + except ValidationError as exc: + raise UnprocessableEntity(exc.json()) + + return view(self, validated, *args, **kwargs) + + return wrapper + + return decorator diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index 3bbbc75f71..abd5cb37f8 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -6051,6 +6051,7 @@ Request body: | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | | form_token | path | | Yes | string | +| payload | body | | Yes | [HumanInputFormSubmitPayload](#humaninputformsubmitpayload) | ##### Responses @@ -13720,6 +13721,12 @@ Request payload for bulk downloading documents as a zip archive. | ---- | ---- | ----------- | -------- | | JSONValue | | | | +#### JsonValue + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| JsonValue | | | | + #### KnowledgeConfig | Name | Type | Description | Required | diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py b/api/tests/unit_tests/controllers/console/test_human_input_form.py index ebf803cac9..11c9c0275b 100644 --- a/api/tests/unit_tests/controllers/console/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -6,8 +6,9 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from flask import Response +from flask import Flask, Response +from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.console.human_input_form import ( ConsoleHumanInputFormApi, ConsoleWorkflowEventsApi, @@ -16,6 +17,7 @@ from controllers.console.human_input_form import ( _jsonify_form_definition, ) from controllers.web.error import NotFoundError +from models.account import AccountStatus from models.enums import CreatorUserRole from models.human_input import RecipientType from models.model import AppMode @@ -47,7 +49,7 @@ def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None: ConsoleHumanInputFormApi._ensure_console_access(form) -def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: expiration = datetime(2024, 1, 1, tzinfo=UTC) definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]}) form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration) @@ -73,7 +75,7 @@ def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> No assert payload["fields"] == ["a"] -def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: class _ServiceStub: def __init__(self, *_args, **_kwargs): pass @@ -93,7 +95,7 @@ def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> handler(api, form_token="token") -def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER) class _ServiceStub: @@ -119,10 +121,14 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) json={"inputs": {"content": "ok"}, "action": "approve"}, ): with pytest.raises(NotFoundError): - handler(api, form_token="token") + handler( + api, + HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}), + form_token="token", + ) -def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP) class _ServiceStub: @@ -148,10 +154,14 @@ def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.Monkey json={"inputs": {"content": "ok"}, "action": "approve"}, ): with pytest.raises(NotFoundError): - handler(api, form_token="token") + handler( + api, + HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}), + form_token="token", + ) -def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: submit_mock = Mock() form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE) @@ -180,13 +190,61 @@ def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None: method="POST", json={"inputs": {"content": "ok"}, "action": "approve"}, ): - response = handler(api, form_token="token") + response = handler( + api, + HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}), + form_token="token", + ) assert response.get_json() == {} submit_mock.assert_called_once() -def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_post_form_decorated_success_validates_request_body(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + submit_mock = Mock() + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE) + current_user = SimpleNamespace(id="user-1", status=AccountStatus.ACTIVE) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + def submit_form_by_token(self, **kwargs): + submit_mock(**kwargs) + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (current_user, "tenant-1"), + ) + monkeypatch.setattr( + "controllers.console.wraps.current_account_with_tenant", + lambda: (current_user, "tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True) + + with app.test_request_context( + "/console/api/form/human_input/token", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + response = ConsoleHumanInputFormApi().post(form_token="token") + + assert response.get_json() == {} + submit_mock.assert_called_once_with( + recipient_type=RecipientType.CONSOLE, + form_token="token", + selected_action_id="approve", + form_data={"content": "ok"}, + submission_user_id="user-1", + ) + + +def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: class _RepoStub: def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): return None @@ -210,7 +268,7 @@ def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None handler(api, workflow_run_id="run-1") -def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", created_by_role=CreatorUserRole.END_USER, @@ -241,7 +299,7 @@ def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) handler(api, workflow_run_id="run-1") -def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", created_by_role=CreatorUserRole.ACCOUNT, @@ -272,7 +330,7 @@ def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) handler(api, workflow_run_id="run-1") -def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", created_by_role=CreatorUserRole.ACCOUNT, diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index c392ffc69d..714b114752 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_login import LoginManager, UserMixin +from pydantic import BaseModel from werkzeug.exceptions import HTTPException from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout @@ -14,6 +15,7 @@ from controllers.console.wraps import ( cloud_edition_billing_resource_check, cloud_utm_record, enterprise_license_required, + model_validate, only_edition_cloud, only_edition_enterprise, only_edition_self_hosted, @@ -135,6 +137,56 @@ class TestCurrentContextInjection: assert Handler().get() == ("tenant-123", current_user) +class TestModelValidationInjection: + """Test request model validation decorator.""" + + class Payload(BaseModel): + name: str + count: int + + def test_should_inject_payload_from_json_body(self): + app = Flask(__name__) + + class Handler: + @model_validate(TestModelValidationInjection.Payload) + def post(self, payload: TestModelValidationInjection.Payload, item_id: str): + return payload, item_id + + with app.test_request_context("/items/item-1", method="POST", json={"name": "alpha", "count": "2"}): + payload, item_id = Handler().post(item_id="item-1") + + assert payload == self.Payload(name="alpha", count=2) + assert item_id == "item-1" + + def test_should_inject_payload_from_query_params(self): + app = Flask(__name__) + + class Handler: + @model_validate(TestModelValidationInjection.Payload) + def get(self, payload: TestModelValidationInjection.Payload): + return payload + + with app.test_request_context("/items?name=alpha&count=2", method="GET"): + payload = Handler().get() + + assert payload == self.Payload(name="alpha", count=2) + + def test_should_raise_unprocessable_entity_for_invalid_payload(self): + app = Flask(__name__) + + class Handler: + @model_validate(TestModelValidationInjection.Payload) + def post(self, payload: TestModelValidationInjection.Payload): + return payload + + with app.test_request_context("/items", method="POST", json={"name": "alpha"}): + with pytest.raises(HTTPException) as exc_info: + Handler().post() + + assert exc_info.value.code == 422 + assert "count" in exc_info.value.description + + class TestEditionChecks: """Test edition-specific decorators""" diff --git a/packages/contracts/generated/api/console/form/orpc.gen.ts b/packages/contracts/generated/api/console/form/orpc.gen.ts index 0d30da9d00..d28f1b4bb4 100644 --- a/packages/contracts/generated/api/console/form/orpc.gen.ts +++ b/packages/contracts/generated/api/console/form/orpc.gen.ts @@ -6,6 +6,7 @@ import * as z from 'zod' import { zGetFormHumanInputByFormTokenPath, zGetFormHumanInputByFormTokenResponse, + zPostFormHumanInputByFormTokenBody, zPostFormHumanInputByFormTokenPath, zPostFormHumanInputByFormTokenResponse, } from './zod.gen' @@ -63,7 +64,12 @@ export const post = oc summary: 'Submit human input form by form token', tags: ['console'], }) - .input(z.object({ params: zPostFormHumanInputByFormTokenPath })) + .input( + z.object({ + body: zPostFormHumanInputByFormTokenBody, + params: zPostFormHumanInputByFormTokenPath, + }), + ) .output(zPostFormHumanInputByFormTokenResponse) export const byFormToken = { diff --git a/packages/contracts/generated/api/console/form/types.gen.ts b/packages/contracts/generated/api/console/form/types.gen.ts index 80c0c1a474..fb908f1c70 100644 --- a/packages/contracts/generated/api/console/form/types.gen.ts +++ b/packages/contracts/generated/api/console/form/types.gen.ts @@ -4,6 +4,16 @@ export type ClientOptions = { baseUrl: `${string}://${string}/console/api` | (string & {}) } +export type HumanInputFormSubmitPayload = { + action: string + form_inputs: { + [key: string]: unknown + } + inputs: { + [key: string]: unknown + } +} + export type GetFormHumanInputByFormTokenData = { body?: never path: { @@ -23,7 +33,7 @@ export type GetFormHumanInputByFormTokenResponse = GetFormHumanInputByFormTokenResponses[keyof GetFormHumanInputByFormTokenResponses] export type PostFormHumanInputByFormTokenData = { - body?: never + body: HumanInputFormSubmitPayload path: { form_token: string } diff --git a/packages/contracts/generated/api/console/form/zod.gen.ts b/packages/contracts/generated/api/console/form/zod.gen.ts index 840b04383e..8d74f49963 100644 --- a/packages/contracts/generated/api/console/form/zod.gen.ts +++ b/packages/contracts/generated/api/console/form/zod.gen.ts @@ -2,6 +2,15 @@ import * as z from 'zod' +/** + * HumanInputFormSubmitPayload + */ +export const zHumanInputFormSubmitPayload = z.object({ + action: z.string(), + form_inputs: z.record(z.string(), z.unknown()), + inputs: z.record(z.string(), z.unknown()), +}) + export const zGetFormHumanInputByFormTokenPath = z.object({ form_token: z.string(), }) @@ -11,6 +20,8 @@ export const zGetFormHumanInputByFormTokenPath = z.object({ */ export const zGetFormHumanInputByFormTokenResponse = z.record(z.string(), z.unknown()) +export const zPostFormHumanInputByFormTokenBody = zHumanInputFormSubmitPayload + export const zPostFormHumanInputByFormTokenPath = z.object({ form_token: z.string(), })