refactor(api): deduplicate TextToAudioPayload and MessageListQuery into controller_schemas.py (#34757)

This commit is contained in:
Jake Armstrong 2026-04-13 07:27:35 +02:00 committed by GitHub
parent 7dd507af04
commit 9121f24181
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 33 deletions

View File

@ -23,9 +23,9 @@ class ConversationRenamePayload(BaseModel):
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID")
first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
class MessageFeedbackPayload(BaseModel):
@ -73,7 +73,7 @@ class WorkflowUpdatePayload(BaseModel):
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")

View File

@ -3,10 +3,10 @@ import logging
from flask import request
from flask_restx import Resource
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@ -86,13 +86,6 @@ class AudioApi(Resource):
raise InternalServerError()
class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(service_api_ns, TextToAudioPayload)

View File

@ -3,10 +3,10 @@ from typing import Literal
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@ -25,7 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import uuid_value
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -41,19 +40,6 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: str = Field(description="Conversation UUID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",

View File

@ -198,7 +198,7 @@ class TestMessageListQuery:
assert q.limit == 20
def test_invalid_conversation_id(self) -> None:
with pytest.raises(ValidationError, match="not a valid uuid"):
with pytest.raises(ValidationError, match="must be a valid UUID"):
MessageListQuery(conversation_id="bad")
def test_limit_bounds(self) -> None:
@ -216,7 +216,7 @@ class TestMessageListQuery:
def test_invalid_first_id(self) -> None:
cid = str(uuid4())
with pytest.raises(ValidationError, match="not a valid uuid"):
with pytest.raises(ValidationError, match="must be a valid UUID"):
MessageListQuery(conversation_id=cid, first_id="invalid")