mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
refactor(api): migrate console conversation variables response model to BaseModel (#35193)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
parent
b65a5fcd97
commit
b1722c8af9
@ -1,44 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
paginated_conversation_variable_fields,
|
||||
)
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
|
||||
paginated_conversation_variable_fields_copy["data"] = fields.List(
|
||||
fields.Nested(conversation_variable_model), attribute="data"
|
||||
)
|
||||
paginated_conversation_variable_model = console_ns.model(
|
||||
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return serialize_value_type(value)
|
||||
except Exception:
|
||||
return serialize_value_type({"value_type": value})
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def _normalize_value(cls, value: Any | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class PaginatedConversationVariableResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[ConversationVariableResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableResponse,
|
||||
PaginatedConversationVariableResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Conversation variables retrieved successfully",
|
||||
console_ns.models[PaginatedConversationVariableResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
rows = session.scalars(stmt).all()
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"total": len(rows),
|
||||
"has_more": False,
|
||||
"data": [
|
||||
{
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
response = PaginatedConversationVariableResponse.model_validate(
|
||||
{
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"total": len(rows),
|
||||
"has_more": False,
|
||||
"data": [
|
||||
ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.app import conversation_variables as conversation_variables_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_variables_module.ConversationVariablesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
updated_at = datetime(2026, 1, 2, tzinfo=UTC)
|
||||
row = SimpleNamespace(
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
to_variable=lambda: SimpleNamespace(
|
||||
model_dump=lambda: {
|
||||
"id": "var-1",
|
||||
"name": "my_var",
|
||||
"value_type": "string",
|
||||
"value": "value",
|
||||
"description": "desc",
|
||||
}
|
||||
),
|
||||
)
|
||||
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
|
||||
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
conversation_variables_module,
|
||||
"sessionmaker",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/conversation-variables",
|
||||
method="GET",
|
||||
query_string={"conversation_id": "conv-1"},
|
||||
):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response["page"] == 1
|
||||
assert response["limit"] == 100
|
||||
assert response["total"] == 1
|
||||
assert response["has_more"] is False
|
||||
assert response["data"][0]["id"] == "var-1"
|
||||
assert response["data"][0]["created_at"] == int(created_at.timestamp())
|
||||
assert response["data"][0]["updated_at"] == int(updated_at.timestamp())
|
||||
|
||||
|
||||
def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_variables_module.ConversationVariablesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
row = SimpleNamespace(
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
to_variable=lambda: SimpleNamespace(
|
||||
model_dump=lambda: {
|
||||
"id": "var-2",
|
||||
"name": "my_var_2",
|
||||
"value_type": SegmentType.INTEGER,
|
||||
"value": 42,
|
||||
"description": None,
|
||||
}
|
||||
),
|
||||
)
|
||||
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
|
||||
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
conversation_variables_module,
|
||||
"sessionmaker",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/conversation-variables",
|
||||
method="GET",
|
||||
query_string={"conversation_id": "conv-1"},
|
||||
):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response["data"][0]["value_type"] == "number"
|
||||
assert response["data"][0]["value"] == "42"
|
||||
|
||||
|
||||
def test_get_conversation_variables_requires_conversation_id(app) -> None:
|
||||
api = conversation_variables_module.ConversationVariablesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
|
||||
with pytest.raises(ValidationError):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
Loading…
Reference in New Issue
Block a user