chore: dep inject for model (#36750)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
Asuka Minato 2026-05-31 02:40:46 +09:00 committed by GitHub
parent 599960024d
commit df40960f5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 203 additions and 18 deletions

View File

@ -12,8 +12,9 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, model_validate, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
@ -33,6 +34,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
register_schema_models(console_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
@ -76,7 +79,9 @@ class ConsoleHumanInputFormApi(Resource):
@account_initialization_required
@login_required
def post(self, form_token: str):
@model_validate(HumanInputFormSubmitPayload)
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
def post(self, payload: HumanInputFormSubmitPayload, form_token: str):
"""
Submit human input form by form token.
@ -90,7 +95,6 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)

View File

@ -7,7 +7,9 @@ from functools import wraps
from typing import Concatenate
from flask import abort, request
from pydantic import BaseModel, ValidationError
from sqlalchemy import select
from werkzeug.exceptions import UnprocessableEntity
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
@ -518,3 +520,38 @@ def with_current_user[T, **P, R](
return view(self, current_user, *args, **kwargs)
return decorated
def model_validate[T, M: BaseModel, **P, R](
model: type[M],
) -> Callable[
[Callable[Concatenate[T, M, P], R]],
Callable[Concatenate[T, P], R],
]:
"""Validate request data and inject the model instance as the first arg after self.
Source is determined by HTTP method:
GET/DELETE -> request.args
POST/PUT/PATCH -> JSON body
"""
def decorator(
view: Callable[Concatenate[T, M, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if request.method in ("GET", "DELETE"):
raw = request.args.to_dict(flat=True)
else:
raw = request.get_json(silent=True) or {}
try:
validated = model.model_validate(raw)
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
return view(self, validated, *args, **kwargs)
return wrapper
return decorator

View File

@ -6051,6 +6051,7 @@ Request body:
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| form_token | path | | Yes | string |
| payload | body | | Yes | [HumanInputFormSubmitPayload](#humaninputformsubmitpayload) |
##### Responses
@ -13720,6 +13721,12 @@ Request payload for bulk downloading documents as a zip archive.
| ---- | ---- | ----------- | -------- |
| JSONValue | | | |
#### JsonValue
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| JsonValue | | | |
#### KnowledgeConfig
| Name | Type | Description | Required |

View File

@ -6,8 +6,9 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Response
from flask import Flask, Response
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console.human_input_form import (
ConsoleHumanInputFormApi,
ConsoleWorkflowEventsApi,
@ -16,6 +17,7 @@ from controllers.console.human_input_form import (
_jsonify_form_definition,
)
from controllers.web.error import NotFoundError
from models.account import AccountStatus
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
@ -47,7 +49,7 @@ def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
ConsoleHumanInputFormApi._ensure_console_access(form)
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
@ -73,7 +75,7 @@ def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> No
assert payload["fields"] == ["a"]
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
@ -93,7 +95,7 @@ def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) ->
handler(api, form_token="token")
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
class _ServiceStub:
@ -119,10 +121,14 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
handler(
api,
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
form_token="token",
)
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
class _ServiceStub:
@ -148,10 +154,14 @@ def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.Monkey
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
handler(
api,
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
form_token="token",
)
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
@ -180,13 +190,61 @@ def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
response = handler(api, form_token="token")
response = handler(
api,
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
form_token="token",
)
assert response.get_json() == {}
submit_mock.assert_called_once()
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_post_form_decorated_success_validates_request_body(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
current_user = SimpleNamespace(id="user-1", status=AccountStatus.ACTIVE)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
def submit_form_by_token(self, **kwargs):
submit_mock(**kwargs)
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (current_user, "tenant-1"),
)
monkeypatch.setattr(
"controllers.console.wraps.current_account_with_tenant",
lambda: (current_user, "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
response = ConsoleHumanInputFormApi().post(form_token="token")
assert response.get_json() == {}
submit_mock.assert_called_once_with(
recipient_type=RecipientType.CONSOLE,
form_token="token",
selected_action_id="approve",
form_data={"content": "ok"},
submission_user_id="user-1",
)
def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return None
@ -210,7 +268,7 @@ def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.END_USER,
@ -241,7 +299,7 @@ def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch)
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,
@ -272,7 +330,7 @@ def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch)
handler(api, workflow_run_id="run-1")
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin
from pydantic import BaseModel
from werkzeug.exceptions import HTTPException
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
@ -14,6 +15,7 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
cloud_utm_record,
enterprise_license_required,
model_validate,
only_edition_cloud,
only_edition_enterprise,
only_edition_self_hosted,
@ -135,6 +137,56 @@ class TestCurrentContextInjection:
assert Handler().get() == ("tenant-123", current_user)
class TestModelValidationInjection:
"""Test request model validation decorator."""
class Payload(BaseModel):
name: str
count: int
def test_should_inject_payload_from_json_body(self):
app = Flask(__name__)
class Handler:
@model_validate(TestModelValidationInjection.Payload)
def post(self, payload: TestModelValidationInjection.Payload, item_id: str):
return payload, item_id
with app.test_request_context("/items/item-1", method="POST", json={"name": "alpha", "count": "2"}):
payload, item_id = Handler().post(item_id="item-1")
assert payload == self.Payload(name="alpha", count=2)
assert item_id == "item-1"
def test_should_inject_payload_from_query_params(self):
app = Flask(__name__)
class Handler:
@model_validate(TestModelValidationInjection.Payload)
def get(self, payload: TestModelValidationInjection.Payload):
return payload
with app.test_request_context("/items?name=alpha&count=2", method="GET"):
payload = Handler().get()
assert payload == self.Payload(name="alpha", count=2)
def test_should_raise_unprocessable_entity_for_invalid_payload(self):
app = Flask(__name__)
class Handler:
@model_validate(TestModelValidationInjection.Payload)
def post(self, payload: TestModelValidationInjection.Payload):
return payload
with app.test_request_context("/items", method="POST", json={"name": "alpha"}):
with pytest.raises(HTTPException) as exc_info:
Handler().post()
assert exc_info.value.code == 422
assert "count" in exc_info.value.description
class TestEditionChecks:
"""Test edition-specific decorators"""

View File

@ -6,6 +6,7 @@ import * as z from 'zod'
import {
zGetFormHumanInputByFormTokenPath,
zGetFormHumanInputByFormTokenResponse,
zPostFormHumanInputByFormTokenBody,
zPostFormHumanInputByFormTokenPath,
zPostFormHumanInputByFormTokenResponse,
} from './zod.gen'
@ -63,7 +64,12 @@ export const post = oc
summary: 'Submit human input form by form token',
tags: ['console'],
})
.input(z.object({ params: zPostFormHumanInputByFormTokenPath }))
.input(
z.object({
body: zPostFormHumanInputByFormTokenBody,
params: zPostFormHumanInputByFormTokenPath,
}),
)
.output(zPostFormHumanInputByFormTokenResponse)
export const byFormToken = {

View File

@ -4,6 +4,16 @@ export type ClientOptions = {
baseUrl: `${string}://${string}/console/api` | (string & {})
}
export type HumanInputFormSubmitPayload = {
action: string
form_inputs: {
[key: string]: unknown
}
inputs: {
[key: string]: unknown
}
}
export type GetFormHumanInputByFormTokenData = {
body?: never
path: {
@ -23,7 +33,7 @@ export type GetFormHumanInputByFormTokenResponse
= GetFormHumanInputByFormTokenResponses[keyof GetFormHumanInputByFormTokenResponses]
export type PostFormHumanInputByFormTokenData = {
body?: never
body: HumanInputFormSubmitPayload
path: {
form_token: string
}

View File

@ -2,6 +2,15 @@
import * as z from 'zod'
/**
* HumanInputFormSubmitPayload
*/
export const zHumanInputFormSubmitPayload = z.object({
action: z.string(),
form_inputs: z.record(z.string(), z.unknown()),
inputs: z.record(z.string(), z.unknown()),
})
export const zGetFormHumanInputByFormTokenPath = z.object({
form_token: z.string(),
})
@ -11,6 +20,8 @@ export const zGetFormHumanInputByFormTokenPath = z.object({
*/
export const zGetFormHumanInputByFormTokenResponse = z.record(z.string(), z.unknown())
export const zPostFormHumanInputByFormTokenBody = zHumanInputFormSubmitPayload
export const zPostFormHumanInputByFormTokenPath = z.object({
form_token: z.string(),
})