diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 369c26a80c..cead33d14f 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,44 +1,86 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from fields.conversation_variable_fields import ( - conversation_variable_fields, - paginated_conversation_variable_fields, -) +from fields._value_type_serializer import serialize_value_type +from fields.base import ResponseModel from libs.login import login_required from models import ConversationVariable from models.model import AppMode -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ConversationVariablesQuery(BaseModel): conversation_id: str = Field(..., description="Conversation ID to filter variables") -console_ns.schema_model( - ConversationVariablesQuery.__name__, - ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base model first -conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields) -# For nested models, need to replace nested dict with registered model -paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy() -paginated_conversation_variable_fields_copy["data"] = fields.List( - fields.Nested(conversation_variable_model), attribute="data" -) -paginated_conversation_variable_model = console_ns.model( - "PaginatedConversationVariable", paginated_conversation_variable_fields_copy +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def _normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type().value) + if isinstance(value, str): + return value + try: + return serialize_value_type(value) + except Exception: + return serialize_value_type({"value_type": value}) + + @field_validator("value", mode="before") + @classmethod + def _normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class PaginatedConversationVariableResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationVariableResponse] + + +register_schema_models( + console_ns, + ConversationVariablesQuery, + ConversationVariableResponse, + PaginatedConversationVariableResponse, ) @@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource): @console_ns.doc(description="Get conversation variables for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__]) - @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model) + @console_ns.response( + 200, + "Conversation variables retrieved successfully", + console_ns.models[PaginatedConversationVariableResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) - @marshal_with(paginated_conversation_variable_model) def get(self, app_model): args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource): with sessionmaker(db.engine, expire_on_commit=False).begin() as session: rows = session.scalars(stmt).all() - return { - "page": page, - "limit": page_size, - "total": len(rows), - "has_more": False, - "data": [ - { - "created_at": row.created_at, - "updated_at": row.updated_at, - **row.to_variable().model_dump(), - } - for row in rows - ], - } + response = PaginatedConversationVariableResponse.model_validate( + { + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ + ConversationVariableResponse.model_validate( + { + "created_at": row.created_at, + "updated_at": row.updated_at, + **row.to_variable().model_dump(), + } + ) + for row in rows + ], + } + ) + return response.model_dump(mode="json") diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py new file mode 100644 index 0000000000..42b3420c31 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from contextlib import nullcontext +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest +from graphon.variables.types import SegmentType +from pydantic import ValidationError + +from controllers.console.app import conversation_variables as conversation_variables_module + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + created_at = datetime(2026, 1, 1, tzinfo=UTC) + updated_at = datetime(2026, 1, 2, tzinfo=UTC) + row = SimpleNamespace( + created_at=created_at, + updated_at=updated_at, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-1", + "name": "my_var", + "value_type": "string", + "value": "value", + "description": "desc", + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["page"] == 1 + assert response["limit"] == 100 + assert response["total"] == 1 + assert response["has_more"] is False + assert response["data"][0]["id"] == "var-1" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["updated_at"] == int(updated_at.timestamp()) + + +def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + row = SimpleNamespace( + created_at=None, + updated_at=None, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-2", + "name": "my_var_2", + "value_type": SegmentType.INTEGER, + "value": 42, + "description": None, + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["data"][0]["value_type"] == "number" + assert response["data"][0]["value"] == "42" + + +def test_get_conversation_variables_requires_conversation_id(app) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"): + with pytest.raises(ValidationError): + method(app_model=SimpleNamespace(id="app-1"))