mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:06:56 +08:00
refactor(api): migrate service conversation-variable responses to BaseModel (#35205)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
parent
f66a3c49c4
commit
800954f8ce
@ -1,7 +1,9 @@
|
|||||||
|
from datetime import datetime
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
|
from graphon.variables.types import SegmentType
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
|
|||||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from fields._value_type_serializer import serialize_value_type
|
||||||
|
from fields.base import ResponseModel
|
||||||
from fields.conversation_fields import (
|
from fields.conversation_fields import (
|
||||||
ConversationInfiniteScrollPagination,
|
ConversationInfiniteScrollPagination,
|
||||||
SimpleConversation,
|
SimpleConversation,
|
||||||
)
|
)
|
||||||
from fields.conversation_variable_fields import (
|
|
||||||
build_conversation_variable_infinite_scroll_pagination_model,
|
|
||||||
build_conversation_variable_model,
|
|
||||||
)
|
|
||||||
from libs.helper import UUIDStrOrEmpty
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
|
|||||||
value: Any
|
value: Any
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableResponse(ResponseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
value_type: str
|
||||||
|
value: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
created_at: int | None = None
|
||||||
|
updated_at: int | None = None
|
||||||
|
|
||||||
|
@field_validator("value_type", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_value_type(cls, value: Any) -> str:
|
||||||
|
exposed_type = getattr(value, "exposed_type", None)
|
||||||
|
if callable(exposed_type):
|
||||||
|
return str(exposed_type().value)
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return str(SegmentType(value).exposed_type().value)
|
||||||
|
except ValueError:
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return serialize_value_type(value)
|
||||||
|
except (AttributeError, TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return serialize_value_type({"value_type": value})
|
||||||
|
except (AttributeError, TypeError, ValueError):
|
||||||
|
value_attr = getattr(value, "value", None)
|
||||||
|
if value_attr is not None:
|
||||||
|
return str(value_attr)
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@field_validator("value", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_value(cls, value: Any | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@field_validator("created_at", "updated_at", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return int(value.timestamp())
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
|
||||||
|
limit: int
|
||||||
|
has_more: bool
|
||||||
|
data: list[ConversationVariableResponse]
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(
|
register_schema_models(
|
||||||
service_api_ns,
|
service_api_ns,
|
||||||
ConversationListQuery,
|
ConversationListQuery,
|
||||||
ConversationRenamePayload,
|
ConversationRenamePayload,
|
||||||
ConversationVariablesQuery,
|
ConversationVariablesQuery,
|
||||||
ConversationVariableUpdatePayload,
|
ConversationVariableUpdatePayload,
|
||||||
|
ConversationVariableResponse,
|
||||||
|
ConversationVariableInfiniteScrollPaginationResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
|
|||||||
404: "Conversation not found",
|
404: "Conversation not found",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@service_api_ns.response(
|
||||||
|
200,
|
||||||
|
"Variables retrieved successfully",
|
||||||
|
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
|
||||||
|
)
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||||
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
|
|
||||||
def get(self, app_model: App, end_user: EndUser, c_id):
|
def get(self, app_model: App, end_user: EndUser, c_id):
|
||||||
"""List all variables for a conversation.
|
"""List all variables for a conversation.
|
||||||
|
|
||||||
@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
|
|||||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.get_conversational_variable(
|
pagination = ConversationService.get_conversational_variable(
|
||||||
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
||||||
)
|
)
|
||||||
|
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||||
|
pagination, from_attributes=True
|
||||||
|
).model_dump(mode="json")
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
|
|||||||
404: "Conversation or variable not found",
|
404: "Conversation or variable not found",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@service_api_ns.response(
|
||||||
|
200,
|
||||||
|
"Variable updated successfully",
|
||||||
|
service_api_ns.models[ConversationVariableResponse.__name__],
|
||||||
|
)
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||||
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
|
|
||||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||||
"""Update a conversation variable's value.
|
"""Update a conversation variable's value.
|
||||||
|
|
||||||
@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
|
|||||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.update_conversation_variable(
|
variable = ConversationService.update_conversation_variable(
|
||||||
app_model, conversation_id, variable_id, end_user, payload.value
|
app_model, conversation_id, variable_id, end_user, payload.value
|
||||||
)
|
)
|
||||||
|
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
except services.errors.conversation.ConversationVariableNotExistsError:
|
except services.errors.conversation.ConversationVariableNotExistsError:
|
||||||
|
|||||||
@ -15,10 +15,12 @@ Focus on:
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from graphon.variables.types import SegmentType
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -29,6 +31,8 @@ from controllers.service_api.app.conversation import (
|
|||||||
ConversationRenameApi,
|
ConversationRenameApi,
|
||||||
ConversationRenamePayload,
|
ConversationRenamePayload,
|
||||||
ConversationVariableDetailApi,
|
ConversationVariableDetailApi,
|
||||||
|
ConversationVariableInfiniteScrollPaginationResponse,
|
||||||
|
ConversationVariableResponse,
|
||||||
ConversationVariablesApi,
|
ConversationVariablesApi,
|
||||||
ConversationVariablesQuery,
|
ConversationVariablesQuery,
|
||||||
ConversationVariableUpdatePayload,
|
ConversationVariableUpdatePayload,
|
||||||
@ -261,6 +265,46 @@ class TestConversationVariableUpdatePayload:
|
|||||||
assert payload.value == nested
|
assert payload.value == nested
|
||||||
|
|
||||||
|
|
||||||
|
class TestConversationVariableResponseModels:
|
||||||
|
def test_variable_response_normalizes_value_type_and_timestamps(self):
|
||||||
|
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||||
|
response = ConversationVariableResponse.model_validate(
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"name": "foo",
|
||||||
|
"value_type": SegmentType.INTEGER,
|
||||||
|
"value": 1,
|
||||||
|
"description": "desc",
|
||||||
|
"created_at": created_at,
|
||||||
|
"updated_at": created_at,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.value_type == "number"
|
||||||
|
assert response.value == "1"
|
||||||
|
assert response.created_at == int(created_at.timestamp())
|
||||||
|
assert response.updated_at == int(created_at.timestamp())
|
||||||
|
|
||||||
|
def test_variable_pagination_response(self):
|
||||||
|
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||||
|
{
|
||||||
|
"limit": 1,
|
||||||
|
"has_more": False,
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"name": "foo",
|
||||||
|
"value_type": "string",
|
||||||
|
"value": "bar",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.limit == 1
|
||||||
|
assert response.has_more is False
|
||||||
|
assert len(response.data) == 1
|
||||||
|
assert response.data[0].name == "foo"
|
||||||
|
|
||||||
|
|
||||||
class TestConversationAppModeValidation:
|
class TestConversationAppModeValidation:
|
||||||
"""Test app mode validation for conversation endpoints."""
|
"""Test app mode validation for conversation endpoints."""
|
||||||
|
|
||||||
@ -549,6 +593,44 @@ class TestConversationVariablesApiController:
|
|||||||
with pytest.raises(NotFound):
|
with pytest.raises(NotFound):
|
||||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||||
|
|
||||||
|
def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ConversationService,
|
||||||
|
"get_conversational_variable",
|
||||||
|
lambda *_args, **_kwargs: SimpleNamespace(
|
||||||
|
limit=1,
|
||||||
|
has_more=False,
|
||||||
|
data=[
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"name": "foo",
|
||||||
|
"value_type": SegmentType.INTEGER,
|
||||||
|
"value": 1,
|
||||||
|
"created_at": created_at,
|
||||||
|
"updated_at": created_at,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
api = ConversationVariablesApi()
|
||||||
|
handler = _unwrap(api.get)
|
||||||
|
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||||
|
end_user = SimpleNamespace()
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/conversations/1/variables?limit=20",
|
||||||
|
method="GET",
|
||||||
|
):
|
||||||
|
result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||||
|
|
||||||
|
assert result["limit"] == 1
|
||||||
|
assert result["has_more"] is False
|
||||||
|
assert result["data"][0]["value_type"] == "number"
|
||||||
|
assert result["data"][0]["value"] == "1"
|
||||||
|
assert result["data"][0]["created_at"] == int(created_at.timestamp())
|
||||||
|
|
||||||
|
|
||||||
class TestConversationVariableDetailApiController:
|
class TestConversationVariableDetailApiController:
|
||||||
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
@ -602,3 +684,41 @@ class TestConversationVariableDetailApiController:
|
|||||||
c_id="00000000-0000-0000-0000-000000000001",
|
c_id="00000000-0000-0000-0000-000000000001",
|
||||||
variable_id="00000000-0000-0000-0000-000000000002",
|
variable_id="00000000-0000-0000-0000-000000000002",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ConversationService,
|
||||||
|
"update_conversation_variable",
|
||||||
|
lambda *_args, **_kwargs: {
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"name": "foo",
|
||||||
|
"value_type": SegmentType.INTEGER,
|
||||||
|
"value": 1,
|
||||||
|
"created_at": created_at,
|
||||||
|
"updated_at": created_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
api = ConversationVariableDetailApi()
|
||||||
|
handler = _unwrap(api.put)
|
||||||
|
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||||
|
end_user = SimpleNamespace()
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/conversations/1/variables/2",
|
||||||
|
method="PUT",
|
||||||
|
json={"value": 1},
|
||||||
|
):
|
||||||
|
result = handler(
|
||||||
|
api,
|
||||||
|
app_model=app_model,
|
||||||
|
end_user=end_user,
|
||||||
|
c_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
variable_id="00000000-0000-0000-0000-000000000002",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["id"] == "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
assert result["value_type"] == "number"
|
||||||
|
assert result["value"] == "1"
|
||||||
|
assert result["created_at"] == int(created_at.timestamp())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user