diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 9f9aa4838c..5c7ea9e69a 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,9 +1,12 @@ import logging +from typing import Literal -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( AppMoreLikeThisDisabledError, @@ -38,6 +41,33 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +class MessageListQuery(BaseModel): + conversation_id: str = Field(description="Conversation UUID") + first_id: str | None = Field(default=None, description="First message ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") + + @field_validator("conversation_id", "first_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") + + +class MessageMoreLikeThisQuery(BaseModel): + response_mode: Literal["blocking", "streaming"] = Field( + description="Response mode", + ) + + +register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery) + + @web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { @@ -68,7 +98,11 @@ class MessageListApi(WebApiResource): @web_ns.doc( params={ "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, - "first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, + "first_id": { + "description": "First message ID for pagination", + "type": "string", + "required": False, + }, "limit": { "description": "Number of messages to return (1-100)", "type": "integer", @@ -93,17 +127,12 @@ class MessageListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args") - .add_argument("first_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict() + query = MessageListQuery.model_validate(raw_args) try: return MessageService.pagination_by_first_id( - app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, end_user, query.conversation_id, query.first_id, query.limit ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -128,7 +157,7 @@ class MessageFeedbackApi(WebApiResource): "enum": ["like", "dislike"], "required": False, }, - "content": {"description": "Feedback content/comment", "type": "string", "required": False}, + "content": {"description": "Feedback content", "type": "string", "required": False}, } ) @web_ns.doc( @@ -145,20 +174,15 @@ class MessageFeedbackApi(WebApiResource): def post(self, app_model, end_user, message_id): message_id = str(message_id) - parser = ( - reqparse.RequestParser() - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - .add_argument("content", type=str, location="json", default=None) - ) - args = parser.parse_args() + payload = MessageFeedbackPayload.model_validate(web_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, message_id=message_id, user=end_user, - rating=args.get("rating"), - content=args.get("content"), + rating=payload.rating, + content=payload.content, ) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -170,17 +194,7 @@ class MessageFeedbackApi(WebApiResource): class MessageMoreLikeThisApi(WebApiResource): @web_ns.doc("Generate More Like This") @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") - @web_ns.doc( - params={ - "message_id": {"description": "Message UUID", "type": "string", "required": True}, - "response_mode": { - "description": "Response mode", - "type": "string", - "enum": ["blocking", "streaming"], - "required": True, - }, - } - ) + @web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -197,12 +211,10 @@ class MessageMoreLikeThisApi(WebApiResource): message_id = str(message_id) - parser = reqparse.RequestParser().add_argument( - "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" - ) - args = parser.parse_args() + raw_args = request.args.to_dict() + query = MessageMoreLikeThisQuery.model_validate(raw_args) - streaming = args["response_mode"] == "streaming" + streaming = query.response_mode == "streaming" try: response = AppGenerateService.generate_more_like_this(