mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
chore: dep inject for model (#36750)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
parent
599960024d
commit
df40960f5d
@ -12,8 +12,9 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, model_validate, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@ -33,6 +34,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(console_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
@ -76,7 +79,9 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_token: str):
|
||||
@model_validate(HumanInputFormSubmitPayload)
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
def post(self, payload: HumanInputFormSubmitPayload, form_token: str):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
@ -90,7 +95,6 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
|
||||
@ -7,7 +7,9 @@ from functools import wraps
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import abort, request
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
|
||||
@ -518,3 +520,38 @@ def with_current_user[T, **P, R](
|
||||
return view(self, current_user, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def model_validate[T, M: BaseModel, **P, R](
|
||||
model: type[M],
|
||||
) -> Callable[
|
||||
[Callable[Concatenate[T, M, P], R]],
|
||||
Callable[Concatenate[T, P], R],
|
||||
]:
|
||||
"""Validate request data and inject the model instance as the first arg after self.
|
||||
|
||||
Source is determined by HTTP method:
|
||||
GET/DELETE -> request.args
|
||||
POST/PUT/PATCH -> JSON body
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
view: Callable[Concatenate[T, M, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if request.method in ("GET", "DELETE"):
|
||||
raw = request.args.to_dict(flat=True)
|
||||
else:
|
||||
raw = request.get_json(silent=True) or {}
|
||||
|
||||
try:
|
||||
validated = model.model_validate(raw)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
return view(self, validated, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@ -6051,6 +6051,7 @@ Request body:
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| form_token | path | | Yes | string |
|
||||
| payload | body | | Yes | [HumanInputFormSubmitPayload](#humaninputformsubmitpayload) |
|
||||
|
||||
##### Responses
|
||||
|
||||
@ -13720,6 +13721,12 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| JSONValue | | | |
|
||||
|
||||
#### JsonValue
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| JsonValue | | | |
|
||||
|
||||
#### KnowledgeConfig
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
|
||||
@ -6,8 +6,9 @@ from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
from flask import Flask, Response
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.console.human_input_form import (
|
||||
ConsoleHumanInputFormApi,
|
||||
ConsoleWorkflowEventsApi,
|
||||
@ -16,6 +17,7 @@ from controllers.console.human_input_form import (
|
||||
_jsonify_form_definition,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from models.account import AccountStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
@ -47,7 +49,7 @@ def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
ConsoleHumanInputFormApi._ensure_console_access(form)
|
||||
|
||||
|
||||
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
|
||||
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
|
||||
@ -73,7 +75,7 @@ def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> No
|
||||
assert payload["fields"] == ["a"]
|
||||
|
||||
|
||||
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
@ -93,7 +95,7 @@ def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
|
||||
|
||||
class _ServiceStub:
|
||||
@ -119,10 +121,14 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
handler(
|
||||
api,
|
||||
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
|
||||
form_token="token",
|
||||
)
|
||||
|
||||
|
||||
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
|
||||
|
||||
class _ServiceStub:
|
||||
@ -148,10 +154,14 @@ def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.Monkey
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
handler(
|
||||
api,
|
||||
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
|
||||
form_token="token",
|
||||
)
|
||||
|
||||
|
||||
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
|
||||
@ -180,13 +190,61 @@ def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
response = handler(api, form_token="token")
|
||||
response = handler(
|
||||
api,
|
||||
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
|
||||
form_token="token",
|
||||
)
|
||||
|
||||
assert response.get_json() == {}
|
||||
submit_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_post_form_decorated_success_validates_request_body(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
current_user = SimpleNamespace(id="user-1", status=AccountStatus.ACTIVE)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
def submit_form_by_token(self, **kwargs):
|
||||
submit_mock(**kwargs)
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (current_user, "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
lambda: (current_user, "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
response = ConsoleHumanInputFormApi().post(form_token="token")
|
||||
|
||||
assert response.get_json() == {}
|
||||
submit_mock.assert_called_once_with(
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
form_token="token",
|
||||
selected_action_id="approve",
|
||||
form_data={"content": "ok"},
|
||||
submission_user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return None
|
||||
@ -210,7 +268,7 @@ def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
@ -241,7 +299,7 @@ def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch)
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
@ -272,7 +330,7 @@ def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch)
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager, UserMixin
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
@ -14,6 +15,7 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
cloud_utm_record,
|
||||
enterprise_license_required,
|
||||
model_validate,
|
||||
only_edition_cloud,
|
||||
only_edition_enterprise,
|
||||
only_edition_self_hosted,
|
||||
@ -135,6 +137,56 @@ class TestCurrentContextInjection:
|
||||
assert Handler().get() == ("tenant-123", current_user)
|
||||
|
||||
|
||||
class TestModelValidationInjection:
|
||||
"""Test request model validation decorator."""
|
||||
|
||||
class Payload(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
|
||||
def test_should_inject_payload_from_json_body(self):
|
||||
app = Flask(__name__)
|
||||
|
||||
class Handler:
|
||||
@model_validate(TestModelValidationInjection.Payload)
|
||||
def post(self, payload: TestModelValidationInjection.Payload, item_id: str):
|
||||
return payload, item_id
|
||||
|
||||
with app.test_request_context("/items/item-1", method="POST", json={"name": "alpha", "count": "2"}):
|
||||
payload, item_id = Handler().post(item_id="item-1")
|
||||
|
||||
assert payload == self.Payload(name="alpha", count=2)
|
||||
assert item_id == "item-1"
|
||||
|
||||
def test_should_inject_payload_from_query_params(self):
|
||||
app = Flask(__name__)
|
||||
|
||||
class Handler:
|
||||
@model_validate(TestModelValidationInjection.Payload)
|
||||
def get(self, payload: TestModelValidationInjection.Payload):
|
||||
return payload
|
||||
|
||||
with app.test_request_context("/items?name=alpha&count=2", method="GET"):
|
||||
payload = Handler().get()
|
||||
|
||||
assert payload == self.Payload(name="alpha", count=2)
|
||||
|
||||
def test_should_raise_unprocessable_entity_for_invalid_payload(self):
|
||||
app = Flask(__name__)
|
||||
|
||||
class Handler:
|
||||
@model_validate(TestModelValidationInjection.Payload)
|
||||
def post(self, payload: TestModelValidationInjection.Payload):
|
||||
return payload
|
||||
|
||||
with app.test_request_context("/items", method="POST", json={"name": "alpha"}):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
Handler().post()
|
||||
|
||||
assert exc_info.value.code == 422
|
||||
assert "count" in exc_info.value.description
|
||||
|
||||
|
||||
class TestEditionChecks:
|
||||
"""Test edition-specific decorators"""
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import * as z from 'zod'
|
||||
import {
|
||||
zGetFormHumanInputByFormTokenPath,
|
||||
zGetFormHumanInputByFormTokenResponse,
|
||||
zPostFormHumanInputByFormTokenBody,
|
||||
zPostFormHumanInputByFormTokenPath,
|
||||
zPostFormHumanInputByFormTokenResponse,
|
||||
} from './zod.gen'
|
||||
@ -63,7 +64,12 @@ export const post = oc
|
||||
summary: 'Submit human input form by form token',
|
||||
tags: ['console'],
|
||||
})
|
||||
.input(z.object({ params: zPostFormHumanInputByFormTokenPath }))
|
||||
.input(
|
||||
z.object({
|
||||
body: zPostFormHumanInputByFormTokenBody,
|
||||
params: zPostFormHumanInputByFormTokenPath,
|
||||
}),
|
||||
)
|
||||
.output(zPostFormHumanInputByFormTokenResponse)
|
||||
|
||||
export const byFormToken = {
|
||||
|
||||
@ -4,6 +4,16 @@ export type ClientOptions = {
|
||||
baseUrl: `${string}://${string}/console/api` | (string & {})
|
||||
}
|
||||
|
||||
export type HumanInputFormSubmitPayload = {
|
||||
action: string
|
||||
form_inputs: {
|
||||
[key: string]: unknown
|
||||
}
|
||||
inputs: {
|
||||
[key: string]: unknown
|
||||
}
|
||||
}
|
||||
|
||||
export type GetFormHumanInputByFormTokenData = {
|
||||
body?: never
|
||||
path: {
|
||||
@ -23,7 +33,7 @@ export type GetFormHumanInputByFormTokenResponse
|
||||
= GetFormHumanInputByFormTokenResponses[keyof GetFormHumanInputByFormTokenResponses]
|
||||
|
||||
export type PostFormHumanInputByFormTokenData = {
|
||||
body?: never
|
||||
body: HumanInputFormSubmitPayload
|
||||
path: {
|
||||
form_token: string
|
||||
}
|
||||
|
||||
@ -2,6 +2,15 @@
|
||||
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* HumanInputFormSubmitPayload
|
||||
*/
|
||||
export const zHumanInputFormSubmitPayload = z.object({
|
||||
action: z.string(),
|
||||
form_inputs: z.record(z.string(), z.unknown()),
|
||||
inputs: z.record(z.string(), z.unknown()),
|
||||
})
|
||||
|
||||
export const zGetFormHumanInputByFormTokenPath = z.object({
|
||||
form_token: z.string(),
|
||||
})
|
||||
@ -11,6 +20,8 @@ export const zGetFormHumanInputByFormTokenPath = z.object({
|
||||
*/
|
||||
export const zGetFormHumanInputByFormTokenResponse = z.record(z.string(), z.unknown())
|
||||
|
||||
export const zPostFormHumanInputByFormTokenBody = zHumanInputFormSubmitPayload
|
||||
|
||||
export const zPostFormHumanInputByFormTokenPath = z.object({
|
||||
form_token: z.string(),
|
||||
})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user