From 9121f24181fea6e2d0c2105fb6894fc2bd93931b Mon Sep 17 00:00:00 2001 From: Jake Armstrong <65635253+jakearmstrong59@users.noreply.github.com> Date: Mon, 13 Apr 2026 07:27:35 +0200 Subject: [PATCH] refactor(api): deduplicate TextToAudioPayload and MessageListQuery into controller_schemas.py (#34757) --- api/controllers/common/controller_schemas.py | 14 +++++++------- api/controllers/service_api/app/audio.py | 9 +-------- api/controllers/web/message.py | 18 ++---------------- .../controllers/web/test_pydantic_models.py | 4 ++-- 4 files changed, 12 insertions(+), 33 deletions(-) diff --git a/api/controllers/common/controller_schemas.py b/api/controllers/common/controller_schemas.py index 39e3b5857d..ec5c72374d 100644 --- a/api/controllers/common/controller_schemas.py +++ b/api/controllers/common/controller_schemas.py @@ -23,9 +23,9 @@ class ConversationRenamePayload(BaseModel): class MessageListQuery(BaseModel): - conversation_id: UUIDStrOrEmpty - first_id: UUIDStrOrEmpty | None = None - limit: int = Field(default=20, ge=1, le=100) + conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID") + first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") class MessageFeedbackPayload(BaseModel): @@ -73,7 +73,7 @@ class WorkflowUpdatePayload(BaseModel): class TextToAudioPayload(BaseModel): - message_id: str | None = None - voice: str | None = None - text: str | None = None - streaming: bool | None = None + message_id: str | None = Field(default=None, description="Message ID") + voice: str | None = Field(default=None, description="Voice to use for TTS") + text: str | None = Field(default=None, description="Text to convert to audio") + streaming: bool | None = Field(default=None, description="Enable streaming response") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 6228cfc25b..907dd1b06d 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -3,10 +3,10 @@ import logging from flask import request from flask_restx import Resource from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.controller_schemas import TextToAudioPayload from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( @@ -86,13 +86,6 @@ class AudioApi(Resource): raise InternalServerError() -class TextToAudioPayload(BaseModel): - message_id: str | None = Field(default=None, description="Message ID") - voice: str | None = Field(default=None, description="Voice to use for TTS") - text: str | None = Field(default=None, description="Text to convert to audio") - streaming: bool | None = Field(default=None, description="Enable streaming response") - - register_schema_model(service_api_ns, TextToAudioPayload) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 25cb6b2b9e..39afdd843f 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -3,10 +3,10 @@ from typing import Literal from flask import request from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field, TypeAdapter, field_validator +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound -from controllers.common.controller_schemas import MessageFeedbackPayload +from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( @@ -25,7 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper -from libs.helper import uuid_value from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -41,19 +40,6 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) -class MessageListQuery(BaseModel): - conversation_id: str = Field(description="Conversation UUID") - first_id: str | None = Field(default=None, description="First message ID for pagination") - limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") - - @field_validator("conversation_id", "first_id") - @classmethod - def validate_uuid(cls, value: str | None) -> str | None: - if value is None: - return value - return uuid_value(value) - - class MessageMoreLikeThisQuery(BaseModel): response_mode: Literal["blocking", "streaming"] = Field( description="Response mode", diff --git a/api/tests/unit_tests/controllers/web/test_pydantic_models.py b/api/tests/unit_tests/controllers/web/test_pydantic_models.py index dcf8133712..bceb65b89f 100644 --- a/api/tests/unit_tests/controllers/web/test_pydantic_models.py +++ b/api/tests/unit_tests/controllers/web/test_pydantic_models.py @@ -198,7 +198,7 @@ class TestMessageListQuery: assert q.limit == 20 def test_invalid_conversation_id(self) -> None: - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id="bad") def test_limit_bounds(self) -> None: @@ -216,7 +216,7 @@ class TestMessageListQuery: def test_invalid_first_id(self) -> None: cid = str(uuid4()) - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id=cid, first_id="invalid")