mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
refactor(api): migrate service conversation-variable responses to BaseModel (#35205)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
parent
f66a3c49c4
commit
800954f8ce
@ -1,7 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
SimpleConversation,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
return serialize_value_type(value)
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
return serialize_value_type({"value_type": value})
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
value_attr = getattr(value, "value", None)
|
||||
if value_attr is not None:
|
||||
return str(value_attr)
|
||||
return str(value)
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def normalize_value(cls, value: Any | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[ConversationVariableResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
ConversationListQuery,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
ConversationVariableResponse,
|
||||
ConversationVariableInfiniteScrollPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
|
||||
404: "Conversation not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Variables retrieved successfully",
|
||||
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""List all variables for a conversation.
|
||||
|
||||
@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
return ConversationService.get_conversational_variable(
|
||||
pagination = ConversationService.get_conversational_variable(
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
||||
)
|
||||
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||
pagination, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
|
||||
404: "Conversation or variable not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Variable updated successfully",
|
||||
service_api_ns.models[ConversationVariableResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
|
||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||
"""Update a conversation variable's value.
|
||||
|
||||
@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
variable = ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationVariableNotExistsError:
|
||||
|
||||
@ -15,10 +15,12 @@ Focus on:
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from graphon.variables.types import SegmentType
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
@ -29,6 +31,8 @@ from controllers.service_api.app.conversation import (
|
||||
ConversationRenameApi,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariableDetailApi,
|
||||
ConversationVariableInfiniteScrollPaginationResponse,
|
||||
ConversationVariableResponse,
|
||||
ConversationVariablesApi,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
@ -261,6 +265,46 @@ class TestConversationVariableUpdatePayload:
|
||||
assert payload.value == nested
|
||||
|
||||
|
||||
class TestConversationVariableResponseModels:
|
||||
def test_variable_response_normalizes_value_type_and_timestamps(self):
|
||||
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||
response = ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": SegmentType.INTEGER,
|
||||
"value": 1,
|
||||
"description": "desc",
|
||||
"created_at": created_at,
|
||||
"updated_at": created_at,
|
||||
}
|
||||
)
|
||||
assert response.value_type == "number"
|
||||
assert response.value == "1"
|
||||
assert response.created_at == int(created_at.timestamp())
|
||||
assert response.updated_at == int(created_at.timestamp())
|
||||
|
||||
def test_variable_pagination_response(self):
|
||||
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||
{
|
||||
"limit": 1,
|
||||
"has_more": False,
|
||||
"data": [
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": "string",
|
||||
"value": "bar",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
assert response.limit == 1
|
||||
assert response.has_more is False
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0].name == "foo"
|
||||
|
||||
|
||||
class TestConversationAppModeValidation:
|
||||
"""Test app mode validation for conversation endpoints."""
|
||||
|
||||
@ -549,6 +593,44 @@ class TestConversationVariablesApiController:
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"get_conversational_variable",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(
|
||||
limit=1,
|
||||
has_more=False,
|
||||
data=[
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": SegmentType.INTEGER,
|
||||
"value": 1,
|
||||
"created_at": created_at,
|
||||
"updated_at": created_at,
|
||||
}
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables?limit=20",
|
||||
method="GET",
|
||||
):
|
||||
result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
assert result["limit"] == 1
|
||||
assert result["has_more"] is False
|
||||
assert result["data"][0]["value_type"] == "number"
|
||||
assert result["data"][0]["value"] == "1"
|
||||
assert result["data"][0]["created_at"] == int(created_at.timestamp())
|
||||
|
||||
|
||||
class TestConversationVariableDetailApiController:
|
||||
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -602,3 +684,41 @@ class TestConversationVariableDetailApiController:
|
||||
c_id="00000000-0000-0000-0000-000000000001",
|
||||
variable_id="00000000-0000-0000-0000-000000000002",
|
||||
)
|
||||
|
||||
def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"update_conversation_variable",
|
||||
lambda *_args, **_kwargs: {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": SegmentType.INTEGER,
|
||||
"value": 1,
|
||||
"created_at": created_at,
|
||||
"updated_at": created_at,
|
||||
},
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables/2",
|
||||
method="PUT",
|
||||
json={"value": 1},
|
||||
):
|
||||
result = handler(
|
||||
api,
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
c_id="00000000-0000-0000-0000-000000000001",
|
||||
variable_id="00000000-0000-0000-0000-000000000002",
|
||||
)
|
||||
|
||||
assert result["id"] == "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert result["value_type"] == "number"
|
||||
assert result["value"] == "1"
|
||||
assert result["created_at"] == int(created_at.timestamp())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user