refactor(api): migrate service conversation-variable responses to BaseModel (#35205)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-14 12:49:21 -07:00 committed by GitHub
parent f66a3c49c4
commit 800954f8ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 198 additions and 8 deletions

View File

@ -1,7 +1,9 @@
from datetime import datetime
from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
value: Any
class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("value_type", mode="before")
@classmethod
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
except ValueError:
return value
try:
return serialize_value_type(value)
except (AttributeError, TypeError, ValueError):
pass
try:
return serialize_value_type({"value_type": value})
except (AttributeError, TypeError, ValueError):
value_attr = getattr(value, "value", None)
if value_attr is not None:
return str(value_attr)
return str(value)
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
limit: int
has_more: bool
data: list[ConversationVariableResponse]
register_schema_models(
service_api_ns,
ConversationListQuery,
ConversationRenamePayload,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
ConversationVariableResponse,
ConversationVariableInfiniteScrollPaginationResponse,
)
@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
404: "Conversation not found",
}
)
@service_api_ns.response(
200,
"Variables retrieved successfully",
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser, c_id):
"""List all variables for a conversation.
@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None
try:
return ConversationService.get_conversational_variable(
pagination = ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
pagination, from_attributes=True
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
404: "Conversation or variable not found",
}
)
@service_api_ns.response(
200,
"Variable updated successfully",
service_api_ns.models[ConversationVariableResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value.
@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.update_conversation_variable(
variable = ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, payload.value
)
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationVariableNotExistsError:

View File

@ -15,10 +15,12 @@ Focus on:
import sys
import uuid
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from graphon.variables.types import SegmentType
from werkzeug.exceptions import BadRequest, NotFound
import services
@ -29,6 +31,8 @@ from controllers.service_api.app.conversation import (
ConversationRenameApi,
ConversationRenamePayload,
ConversationVariableDetailApi,
ConversationVariableInfiniteScrollPaginationResponse,
ConversationVariableResponse,
ConversationVariablesApi,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
@ -261,6 +265,46 @@ class TestConversationVariableUpdatePayload:
assert payload.value == nested
class TestConversationVariableResponseModels:
def test_variable_response_normalizes_value_type_and_timestamps(self):
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = ConversationVariableResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER,
"value": 1,
"description": "desc",
"created_at": created_at,
"updated_at": created_at,
}
)
assert response.value_type == "number"
assert response.value == "1"
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{
"limit": 1,
"has_more": False,
"data": [
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": "string",
"value": "bar",
}
],
}
)
assert response.limit == 1
assert response.has_more is False
assert len(response.data) == 1
assert response.data[0].name == "foo"
class TestConversationAppModeValidation:
"""Test app mode validation for conversation endpoints."""
@ -549,6 +593,44 @@ class TestConversationVariablesApiController:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,
"get_conversational_variable",
lambda *_args, **_kwargs: SimpleNamespace(
limit=1,
has_more=False,
data=[
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER,
"value": 1,
"created_at": created_at,
"updated_at": created_at,
}
],
),
)
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables?limit=20",
method="GET",
):
result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
assert result["limit"] == 1
assert result["has_more"] is False
assert result["data"][0]["value_type"] == "number"
assert result["data"][0]["value"] == "1"
assert result["data"][0]["created_at"] == int(created_at.timestamp())
class TestConversationVariableDetailApiController:
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -602,3 +684,41 @@ class TestConversationVariableDetailApiController:
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
lambda *_args, **_kwargs: {
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER,
"value": 1,
"created_at": created_at,
"updated_at": created_at,
},
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables/2",
method="PUT",
json={"value": 1},
):
result = handler(
api,
app_model=app_model,
end_user=end_user,
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)
assert result["id"] == "550e8400-e29b-41d4-a716-446655440000"
assert result["value_type"] == "number"
assert result["value"] == "1"
assert result["created_at"] == int(created_at.timestamp())