refactor(api): migrate console conversation variables response model to BaseModel (#35193)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-14 12:51:33 -07:00 committed by GitHub
parent b65a5fcd97
commit b1722c8af9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 196 additions and 38 deletions

View File

@ -1,44 +1,86 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import (
conversation_variable_fields,
paginated_conversation_variable_fields,
)
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("value_type", mode="before")
@classmethod
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
return value
try:
return serialize_value_type(value)
except Exception:
return serialize_value_type({"value_type": value})
@field_validator("value", mode="before")
@classmethod
def _normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class PaginatedConversationVariableResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[ConversationVariableResponse]
register_schema_models(
console_ns,
ConversationVariablesQuery,
ConversationVariableResponse,
PaginatedConversationVariableResponse,
)
@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@console_ns.response(
200,
"Conversation variables retrieved successfully",
console_ns.models[PaginatedConversationVariableResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
rows = session.scalars(stmt).all()
return {
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
response = PaginatedConversationVariableResponse.model_validate(
{
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
ConversationVariableResponse.model_validate(
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
)
for row in rows
],
}
)
return response.model_dump(mode="json")

View File

@ -0,0 +1,108 @@
from __future__ import annotations
from contextlib import nullcontext
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from graphon.variables.types import SegmentType
from pydantic import ValidationError
from controllers.console.app import conversation_variables as conversation_variables_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
created_at = datetime(2026, 1, 1, tzinfo=UTC)
updated_at = datetime(2026, 1, 2, tzinfo=UTC)
row = SimpleNamespace(
created_at=created_at,
updated_at=updated_at,
to_variable=lambda: SimpleNamespace(
model_dump=lambda: {
"id": "var-1",
"name": "my_var",
"value_type": "string",
"value": "value",
"description": "desc",
}
),
)
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
conversation_variables_module,
"sessionmaker",
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
)
with app.test_request_context(
"/console/api/apps/app-1/conversation-variables",
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response["page"] == 1
assert response["limit"] == 100
assert response["total"] == 1
assert response["has_more"] is False
assert response["data"][0]["id"] == "var-1"
assert response["data"][0]["created_at"] == int(created_at.timestamp())
assert response["data"][0]["updated_at"] == int(updated_at.timestamp())
def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
row = SimpleNamespace(
created_at=None,
updated_at=None,
to_variable=lambda: SimpleNamespace(
model_dump=lambda: {
"id": "var-2",
"name": "my_var_2",
"value_type": SegmentType.INTEGER,
"value": 42,
"description": None,
}
),
)
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
conversation_variables_module,
"sessionmaker",
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
)
with app.test_request_context(
"/console/api/apps/app-1/conversation-variables",
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response["data"][0]["value_type"] == "number"
assert response["data"][0]["value"] == "42"
def test_get_conversation_variables_requires_conversation_id(app) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
with pytest.raises(ValidationError):
method(app_model=SimpleNamespace(id="app-1"))