From 778aabb4859199db7c9f0eafc690d6d3cb2745b6 Mon Sep 17 00:00:00 2001 From: Sean Kenneth Doherty Date: Wed, 4 Feb 2026 00:36:52 -0600 Subject: [PATCH] refactor(api): replace reqparse with Pydantic models in trial.py (#31789) Co-authored-by: Asuka Minato --- api/controllers/console/explore/trial.py | 101 ++++++++++++++++------- 1 file changed, 71 insertions(+), 30 deletions(-) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index cd523b481c..ba214e71c0 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -1,8 +1,9 @@ import logging -from typing import Any, cast +from typing import Any, Literal, cast from flask import request -from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy) +# Pydantic models for request validation +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowRunRequest(BaseModel): + inputs: dict + files: list | None = None + + +class ChatRequest(BaseModel): + inputs: dict + query: str + files: list | None = None + conversation_id: str | None = None + parent_message_id: str | None = None + retriever_from: str = "explore_app" + + +class TextToSpeechRequest(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + +class CompletionRequest(BaseModel): + inputs: dict + query: str = "" + files: list | None = None + response_mode: Literal["blocking", "streaming"] | None = None + retriever_from: str = "explore_app" + + +# Register schemas for Swagger documentation +console_ns.schema_model( + WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + + class TrialAppWorkflowRunApi(TrialAppResource): + @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__]) def post(self, trial_app): """ Run workflow @@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - args = parser.parse_args() + request_data = WorkflowRunRequest.model_validate(console_ns.payload) + args = request_data.model_dump() assert current_user is not None try: app_id = app_model.id @@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource): class TrialChatApi(TrialAppResource): + @console_ns.expect(console_ns.models[ChatRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app @@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - args = parser.parse_args() + request_data = ChatRequest.model_validate(console_ns.payload) + args = request_data.model_dump() + + # Validate UUID values if provided + if args.get("conversation_id"): + args["conversation_id"] = uuid_value(args["conversation_id"]) + if args.get("parent_message_id"): + args["parent_message_id"] = uuid_value(args["parent_message_id"]) args["auto_generate_name"] = False @@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource): class TrialChatTextApi(TrialAppResource): + @console_ns.expect(console_ns.models[TextToSpeechRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") - args = parser.parse_args() + request_data = TextToSpeechRequest.model_validate(console_ns.payload) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = request_data.message_id + text = request_data.text + voice = request_data.voice if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") @@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource): class TrialCompletionApi(TrialAppResource): + @console_ns.expect(console_ns.models[CompletionRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - args = parser.parse_args() + request_data = CompletionRequest.model_validate(console_ns.payload) + args = request_data.model_dump() streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False