mirror of https://github.com/langgenius/dify.git
refactor: more ns.model to BaseModel (#30445)
This commit is contained in:
parent
151101aaf5
commit
2cef879209
|
|
@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
|
|
@ -177,6 +176,12 @@ annotation_hit_history_model = console_ns.model(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError
|
|||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
ResultResponse,
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
|
@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
|
|||
endpoint="installed_app_conversations",
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
|
|
@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
with Session(db.engine) as session:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
|
|
@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource):
|
|||
invoke_from=InvokeFrom.EXPLORE,
|
||||
pinned=args.pinned,
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
|
@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource):
|
|||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource):
|
|||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
|
|
@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource):
|
|||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
return ConversationService.rename(
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, current_user, payload.name, payload.auto_generate
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
|
@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource):
|
|||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource):
|
|||
raise ValueError("current_user must be an Account instance")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import logging
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
|
@ -23,7 +22,8 @@ from controllers.console.explore.wraps import InstalledAppResource
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
|
|||
endpoint="installed_app_messages",
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource):
|
|||
args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.conversation_id),
|
||||
str(args.first_id) if args.first_id else None,
|
||||
args.limit,
|
||||
)
|
||||
adapter = TypeAdapter(MessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return MessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
|
|
@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
|
@ -26,28 +26,8 @@ class SavedMessageCreatePayload(BaseModel):
|
|||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -57,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
|
||||
args = SavedMessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(
|
||||
pagination = SavedMessageService.pagination_by_last_id(
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.last_id) if args.last_id else None,
|
||||
args.limit,
|
||||
)
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
|
|
@ -78,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -96,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
|
|||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@ from uuid import UUID
|
|||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
|
|
@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
build_conversation_delete_model,
|
||||
build_conversation_infinite_scroll_pagination_model,
|
||||
build_simple_conversation_model,
|
||||
ConversationDelete,
|
||||
ConversationInfiniteScrollPagination,
|
||||
SimpleConversation,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
|
|
@ -105,7 +104,6 @@ class ConversationApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""List all conversations for the current user.
|
||||
|
||||
|
|
@ -120,7 +118,7 @@ class ConversationApi(Resource):
|
|||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
return ConversationService.pagination_by_last_id(
|
||||
pagination = ConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
|
|
@ -129,6 +127,13 @@ class ConversationApi(Resource):
|
|||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=query_args.sort_by,
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
|
@ -146,7 +151,6 @@ class ConversationDetailApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -159,7 +163,7 @@ class ConversationDetailApi(Resource):
|
|||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 204
|
||||
return ConversationDelete(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
|
@ -176,7 +180,6 @@ class ConversationRenameApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -188,7 +191,14 @@ class ConversationRenameApi(Resource):
|
|||
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, end_user, payload.name, payload.auto_generate
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns
|
|||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel):
|
|||
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Namespace):
|
||||
"""Build the message model for the API or Namespace."""
|
||||
# First build the nested models
|
||||
feedback_model = build_feedback_model(api_or_ns)
|
||||
agent_thought_model = build_agent_thought_model(api_or_ns)
|
||||
message_file_model = build_message_file_model(api_or_ns)
|
||||
|
||||
# Then build the message fields with nested models
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.Raw(
|
||||
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
|
||||
if obj.message_metadata
|
||||
else []
|
||||
),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
return api_or_ns.model("Message", message_fields)
|
||||
|
||||
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the message infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested message model first
|
||||
message_model = build_message_model(api_or_ns)
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_model)),
|
||||
}
|
||||
return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
|
||||
|
||||
|
||||
@service_api_ns.route("/messages")
|
||||
class MessageListApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
|
||||
|
|
@ -104,7 +58,6 @@ class MessageListApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""List messages in a conversation.
|
||||
|
||||
|
|
@ -119,9 +72,16 @@ class MessageListApi(Resource):
|
|||
first_id = str(query_args.first_id) if query_args.first_id else None
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model, end_user, conversation_id, first_id, query_args.limit
|
||||
)
|
||||
adapter = TypeAdapter(MessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return MessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
|
|
@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/app/feedbacks")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -8,7 +9,11 @@ from controllers.web.error import NotChatAppError
|
|||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
ResultResponse,
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
|
|
@ -54,7 +59,6 @@ class ConversationListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -82,7 +86,7 @@ class ConversationListApi(WebApiResource):
|
|||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
|
|
@ -92,16 +96,19 @@ class ConversationListApi(WebApiResource):
|
|||
pinned=pinned,
|
||||
sort_by=args["sort_by"],
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>")
|
||||
class ConversationApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Conversation")
|
||||
@web_ns.doc(description="Delete a specific conversation.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
|
|
@ -115,7 +122,6 @@ class ConversationApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -126,7 +132,7 @@ class ConversationApi(WebApiResource):
|
|||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
|
@ -155,7 +161,6 @@ class ConversationRenameApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -171,17 +176,20 @@ class ConversationRenameApi(WebApiResource):
|
|||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/pin")
|
||||
class ConversationPinApi(WebApiResource):
|
||||
pin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Pin Conversation")
|
||||
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
|
|
@ -195,7 +203,6 @@ class ConversationPinApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(pin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -208,15 +215,11 @@ class ConversationPinApi(WebApiResource):
|
|||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/unpin")
|
||||
class ConversationUnPinApi(WebApiResource):
|
||||
unpin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Unpin Conversation")
|
||||
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
|
|
@ -230,7 +233,6 @@ class ConversationUnPinApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(unpin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -239,4 +241,4 @@ class ConversationUnPinApi(WebApiResource):
|
|||
conversation_id = str(c_id)
|
||||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import logging
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
|
@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
|
|
@ -70,29 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message
|
|||
|
||||
@web_ns.route("/messages")
|
||||
class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Message List")
|
||||
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
|
||||
@web_ns.doc(
|
||||
|
|
@ -121,7 +96,6 @@ class MessageListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -131,9 +105,16 @@ class MessageListApi(WebApiResource):
|
|||
query = MessageListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model, end_user, query.conversation_id, query.first_id, query.limit
|
||||
)
|
||||
adapter = TypeAdapter(WebMessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return WebMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
|
|
@ -142,10 +123,6 @@ class MessageListApi(WebApiResource):
|
|||
|
||||
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(WebApiResource):
|
||||
feedback_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Create Message Feedback")
|
||||
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
|
|
@ -170,7 +147,6 @@ class MessageFeedbackApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(feedback_response_fields)
|
||||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
|
|
@ -187,7 +163,7 @@ class MessageFeedbackApi(WebApiResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
|
||||
|
|
@ -247,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
|
||||
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
suggested_questions_response_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Suggested Questions")
|
||||
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
|
|
@ -264,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(suggested_questions_response_fields)
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -277,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
|
||||
)
|
||||
# questions is a list of strings, not a list of Message objects
|
||||
# so we can directly return it
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
|
|
@ -296,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -1,40 +1,20 @@
|
|||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import uuid_value
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages")
|
||||
class SavedMessageListApi(WebApiResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
post_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Saved Messages")
|
||||
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
|
||||
@web_ns.doc(
|
||||
|
|
@ -58,7 +38,6 @@ class SavedMessageListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
|
@ -70,7 +49,14 @@ class SavedMessageListApi(WebApiResource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
|
||||
@web_ns.doc("Save Message")
|
||||
@web_ns.doc(description="Save a specific message for later reference.")
|
||||
|
|
@ -89,7 +75,6 @@ class SavedMessageListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(post_response_fields)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
|
@ -102,15 +87,11 @@ class SavedMessageListApi(WebApiResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages/<uuid:message_id>")
|
||||
class SavedMessageApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Saved Message")
|
||||
@web_ns.doc(description="Remove a message from saved messages.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
|
||||
|
|
@ -124,7 +105,6 @@ class SavedMessageApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
|
|
@ -133,4 +113,4 @@ class SavedMessageApi(WebApiResource):
|
|||
|
||||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
|
|
|||
|
|
@ -1,236 +1,338 @@
|
|||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from .raws import FilesContainedField
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.file import File
|
||||
|
||||
JSONValue: TypeAlias = Any
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
feedback_fields = {
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
}
|
||||
class MessageFile(ResponseModel):
|
||||
id: str
|
||||
filename: str
|
||||
type: str
|
||||
url: str | None = None
|
||||
mime_type: str | None = None
|
||||
size: int | None = None
|
||||
transfer_method: str
|
||||
belongs_to: str | None = None
|
||||
upload_file_id: str | None = None
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
message_file_fields = {
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
}
|
||||
@field_validator("transfer_method", mode="before")
|
||||
@classmethod
|
||||
def _normalize_transfer_method(cls, value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def build_message_file_model(api_or_ns: Namespace):
|
||||
"""Build the message file fields for the API or Namespace."""
|
||||
return api_or_ns.model("MessageFile", message_file_fields)
|
||||
class SimpleConversation(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputs: dict[str, JSONValue]
|
||||
status: str
|
||||
introduction: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_fields)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_fields, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
}
|
||||
|
||||
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
|
||||
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
|
||||
model_config_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
}
|
||||
|
||||
simple_message_detail_fields = {
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_fields, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields), attribute="items"),
|
||||
}
|
||||
|
||||
conversation_message_detail_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_fields),
|
||||
"message": fields.Nested(message_detail_fields, attribute="first_message"),
|
||||
}
|
||||
|
||||
conversation_with_summary_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"status_count": fields.Nested(status_count_fields),
|
||||
}
|
||||
|
||||
conversation_with_summary_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
|
||||
}
|
||||
|
||||
conversation_detail_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_fields),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
}
|
||||
|
||||
simple_conversation_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"status": fields.String,
|
||||
"introduction": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
conversation_delete_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(simple_conversation_fields)),
|
||||
}
|
||||
class ConversationInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[SimpleConversation]
|
||||
|
||||
|
||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
|
||||
simple_conversation_model = build_simple_conversation_model(api_or_ns)
|
||||
|
||||
copied_fields = conversation_infinite_scroll_pagination_fields.copy()
|
||||
copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
|
||||
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
|
||||
class ConversationDelete(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
def build_conversation_delete_model(api_or_ns: Namespace):
|
||||
"""Build the conversation delete model for the API or Namespace."""
|
||||
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
|
||||
class ResultResponse(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
def build_simple_conversation_model(api_or_ns: Namespace):
|
||||
"""Build the simple conversation model for the API or Namespace."""
|
||||
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
|
||||
class SimpleAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class Feedback(ResponseModel):
|
||||
rating: str
|
||||
content: str | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account: SimpleAccount | None = None
|
||||
|
||||
|
||||
class Annotation(ResponseModel):
|
||||
id: str
|
||||
question: str | None = None
|
||||
content: str
|
||||
account: SimpleAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class AnnotationHitHistory(ResponseModel):
|
||||
annotation_id: str
|
||||
annotation_create_account: SimpleAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class AgentThought(ResponseModel):
|
||||
id: str
|
||||
chain_id: str | None = None
|
||||
message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
|
||||
message_id: str
|
||||
position: int
|
||||
thought: str | None = None
|
||||
tool: str | None = None
|
||||
tool_labels: JSONValue
|
||||
tool_input: str | None = None
|
||||
created_at: int | None = None
|
||||
observation: str | None = None
|
||||
files: list[str]
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _fallback_chain_id(self):
|
||||
if self.chain_id is None and self.message_chain_id:
|
||||
self.chain_id = self.message_chain_id
|
||||
return self
|
||||
|
||||
|
||||
class MessageDetail(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
inputs: dict[str, JSONValue]
|
||||
query: str
|
||||
message: JSONValue
|
||||
message_tokens: int
|
||||
answer: str
|
||||
answer_tokens: int
|
||||
provider_response_latency: float
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
feedbacks: list[Feedback]
|
||||
workflow_run_id: str | None = None
|
||||
annotation: Annotation | None = None
|
||||
annotation_hit_history: AnnotationHitHistory | None = None
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought]
|
||||
message_files: list[MessageFile]
|
||||
metadata: JSONValue
|
||||
status: str
|
||||
error: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class FeedbackStat(ResponseModel):
|
||||
like: int
|
||||
dislike: int
|
||||
|
||||
|
||||
class StatusCount(ResponseModel):
|
||||
success: int
|
||||
failed: int
|
||||
partial_success: int
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: JSONValue | None = None
|
||||
model: JSONValue | None = None
|
||||
user_input_form: JSONValue | None = None
|
||||
pre_prompt: str | None = None
|
||||
agent_mode: JSONValue | None = None
|
||||
|
||||
|
||||
class SimpleModelConfig(ResponseModel):
|
||||
model: JSONValue | None = None
|
||||
pre_prompt: str | None = None
|
||||
|
||||
|
||||
class SimpleMessageDetail(ResponseModel):
|
||||
inputs: dict[str, JSONValue]
|
||||
query: str
|
||||
message: str
|
||||
answer: str
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
|
||||
class Conversation(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_end_user_session_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
from_account_name: str | None = None
|
||||
read_at: int | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
annotation: Annotation | None = None
|
||||
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
|
||||
user_feedback_stats: FeedbackStat | None = None
|
||||
admin_feedback_stats: FeedbackStat | None = None
|
||||
message: SimpleMessageDetail | None = None
|
||||
|
||||
|
||||
class ConversationPagination(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[Conversation]
|
||||
|
||||
|
||||
class ConversationMessageDetail(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: int | None = None
|
||||
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
|
||||
message: MessageDetail | None = None
|
||||
|
||||
|
||||
class ConversationWithSummary(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_end_user_session_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
from_account_name: str | None = None
|
||||
name: str
|
||||
summary: str
|
||||
read_at: int | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
annotated: bool
|
||||
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
|
||||
message_count: int
|
||||
user_feedback_stats: FeedbackStat | None = None
|
||||
admin_feedback_stats: FeedbackStat | None = None
|
||||
status_count: StatusCount | None = None
|
||||
|
||||
|
||||
class ConversationWithSummaryPagination(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[ConversationWithSummary]
|
||||
|
||||
|
||||
class ConversationDetail(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
annotated: bool
|
||||
introduction: str | None = None
|
||||
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
|
||||
message_count: int
|
||||
user_feedback_stats: FeedbackStat | None = None
|
||||
admin_feedback_stats: FeedbackStat | None = None
|
||||
|
||||
|
||||
def to_timestamp(value: datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def format_files_contained(value: JSONValue) -> JSONValue:
|
||||
if isinstance(value, File):
|
||||
return value.model_dump()
|
||||
if isinstance(value, dict):
|
||||
return {k: format_files_contained(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [format_files_contained(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def message_text(value: JSONValue) -> str:
|
||||
if isinstance(value, list) and value:
|
||||
first = value[0]
|
||||
if isinstance(first, dict):
|
||||
text = first.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return ""
|
||||
|
||||
|
||||
def extract_model_config(value: object | None) -> dict[str, JSONValue]:
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if hasattr(value, "to_dict"):
|
||||
return value.to_dict()
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -1,77 +1,137 @@
|
|||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
from typing import TypeAlias
|
||||
|
||||
from .raws import FilesContainedField
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
feedback_fields = {
|
||||
"rating": fields.String,
|
||||
}
|
||||
from core.file import File
|
||||
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
||||
|
||||
JSONValueType: TypeAlias = JSONValue
|
||||
|
||||
|
||||
def build_feedback_model(api_or_ns: Namespace):
|
||||
"""Build the feedback model for the API or Namespace."""
|
||||
return api_or_ns.model("Feedback", feedback_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
}
|
||||
class SimpleFeedback(ResponseModel):
|
||||
rating: str | None = None
|
||||
|
||||
|
||||
def build_agent_thought_model(api_or_ns: Namespace):
|
||||
"""Build the agent thought model for the API or Namespace."""
|
||||
return api_or_ns.model("AgentThought", agent_thought_fields)
|
||||
class RetrieverResource(ResponseModel):
|
||||
id: str
|
||||
message_id: str
|
||||
position: int
|
||||
dataset_id: str | None = None
|
||||
dataset_name: str | None = None
|
||||
document_id: str | None = None
|
||||
document_name: str | None = None
|
||||
data_source_type: str | None = None
|
||||
segment_id: str | None = None
|
||||
score: float | None = None
|
||||
hit_count: int | None = None
|
||||
word_count: int | None = None
|
||||
segment_position: int | None = None
|
||||
index_node_hash: str | None = None
|
||||
content: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
class MessageListItem(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
parent_message_id: str | None = None
|
||||
inputs: dict[str, JSONValueType]
|
||||
query: str
|
||||
answer: str = Field(validation_alias="re_sign_file_url_answer")
|
||||
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
|
||||
retriever_resources: list[RetrieverResource]
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought]
|
||||
message_files: list[MessageFile]
|
||||
status: str
|
||||
error: str | None = None
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
|
||||
return format_files_contained(value)
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class WebMessageListItem(MessageListItem):
|
||||
metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
|
||||
|
||||
|
||||
class MessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[MessageListItem]
|
||||
|
||||
|
||||
class WebMessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[WebMessageListItem]
|
||||
|
||||
|
||||
class SavedMessageItem(ResponseModel):
|
||||
id: str
|
||||
inputs: dict[str, JSONValueType]
|
||||
query: str
|
||||
answer: str
|
||||
message_files: list[MessageFile]
|
||||
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class SavedMessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[SavedMessageItem]
|
||||
|
||||
|
||||
class SuggestedQuestionsResponse(ResponseModel):
|
||||
data: list[str]
|
||||
|
||||
|
||||
def to_timestamp(value: datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def format_files_contained(value: JSONValueType) -> JSONValueType:
|
||||
if isinstance(value, File):
|
||||
return value.model_dump()
|
||||
if isinstance(value, dict):
|
||||
return {k: format_files_contained(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [format_files_contained(v) for v in value]
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -0,0 +1,174 @@
|
|||
"""Unit tests for controllers.web.message message list mapping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from datetime import datetime
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_controller_module():
|
||||
"""Import controllers.web.message using a stub package."""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
parent_module_name = "controllers.web"
|
||||
module_name = f"{parent_module_name}.message"
|
||||
|
||||
if parent_module_name not in sys.modules:
|
||||
from flask_restx import Namespace
|
||||
|
||||
stub = ModuleType(parent_module_name)
|
||||
stub.__file__ = "controllers/web/__init__.py"
|
||||
stub.__path__ = ["controllers/web"]
|
||||
stub.__package__ = "controllers"
|
||||
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||
sys.modules[parent_module_name] = stub
|
||||
|
||||
wraps_module_name = f"{parent_module_name}.wraps"
|
||||
if wraps_module_name not in sys.modules:
|
||||
wraps_stub = ModuleType(wraps_module_name)
|
||||
|
||||
class WebApiResource:
|
||||
pass
|
||||
|
||||
wraps_stub.WebApiResource = WebApiResource
|
||||
sys.modules[wraps_module_name] = wraps_stub
|
||||
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
message_module = _load_controller_module()
|
||||
MessageListApi = message_module.MessageListApi
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
def test_message_list_mapping(app: Flask) -> None:
|
||||
conversation_id = str(uuid4())
|
||||
message_id = str(uuid4())
|
||||
|
||||
created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
resource_created_at = datetime(2024, 1, 1, 13, 0, 0)
|
||||
thought_created_at = datetime(2024, 1, 1, 14, 0, 0)
|
||||
|
||||
retriever_resource_obj = SimpleNamespace(
|
||||
id="res-obj",
|
||||
message_id=message_id,
|
||||
position=2,
|
||||
dataset_id="ds-1",
|
||||
dataset_name="dataset",
|
||||
document_id="doc-1",
|
||||
document_name="document",
|
||||
data_source_type="file",
|
||||
segment_id="seg-1",
|
||||
score=0.9,
|
||||
hit_count=1,
|
||||
word_count=10,
|
||||
segment_position=0,
|
||||
index_node_hash="hash",
|
||||
content="content",
|
||||
created_at=resource_created_at,
|
||||
)
|
||||
|
||||
agent_thought = SimpleNamespace(
|
||||
id="thought-1",
|
||||
chain_id=None,
|
||||
message_chain_id="chain-1",
|
||||
message_id=message_id,
|
||||
position=1,
|
||||
thought="thinking",
|
||||
tool="tool",
|
||||
tool_labels={"label": "value"},
|
||||
tool_input="{}",
|
||||
created_at=thought_created_at,
|
||||
observation="observed",
|
||||
files=["file-a"],
|
||||
)
|
||||
|
||||
message_file_obj = SimpleNamespace(
|
||||
id="file-obj",
|
||||
filename="b.txt",
|
||||
type="file",
|
||||
url=None,
|
||||
mime_type=None,
|
||||
size=None,
|
||||
transfer_method="local",
|
||||
belongs_to=None,
|
||||
upload_file_id=None,
|
||||
)
|
||||
|
||||
message = SimpleNamespace(
|
||||
id=message_id,
|
||||
conversation_id=conversation_id,
|
||||
parent_message_id=None,
|
||||
inputs={"foo": "bar"},
|
||||
query="hello",
|
||||
re_sign_file_url_answer="answer",
|
||||
user_feedback=SimpleNamespace(rating="like"),
|
||||
retriever_resources=[
|
||||
{"id": "res-dict", "message_id": message_id, "position": 1},
|
||||
retriever_resource_obj,
|
||||
],
|
||||
created_at=created_at,
|
||||
agent_thoughts=[agent_thought],
|
||||
message_files=[
|
||||
{"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"},
|
||||
message_file_obj,
|
||||
],
|
||||
status="success",
|
||||
error=None,
|
||||
message_metadata_dict={"meta": "value"},
|
||||
)
|
||||
|
||||
pagination = SimpleNamespace(limit=20, has_more=False, data=[message])
|
||||
app_model = SimpleNamespace(mode="chat")
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with (
|
||||
patch.object(message_module.MessageService, "pagination_by_first_id", return_value=pagination) as mock_page,
|
||||
app.test_request_context(f"/messages?conversation_id={conversation_id}&limit=20"),
|
||||
):
|
||||
response = MessageListApi().get(app_model, end_user)
|
||||
|
||||
mock_page.assert_called_once_with(app_model, end_user, conversation_id, None, 20)
|
||||
assert response["limit"] == 20
|
||||
assert response["has_more"] is False
|
||||
assert len(response["data"]) == 1
|
||||
|
||||
item = response["data"][0]
|
||||
assert item["id"] == message_id
|
||||
assert item["conversation_id"] == conversation_id
|
||||
assert item["inputs"] == {"foo": "bar"}
|
||||
assert item["answer"] == "answer"
|
||||
assert item["feedback"]["rating"] == "like"
|
||||
assert item["metadata"] == {"meta": "value"}
|
||||
assert item["created_at"] == int(created_at.timestamp())
|
||||
|
||||
assert item["retriever_resources"][0]["id"] == "res-dict"
|
||||
assert item["retriever_resources"][1]["id"] == "res-obj"
|
||||
assert item["retriever_resources"][1]["created_at"] == int(resource_created_at.timestamp())
|
||||
|
||||
assert item["agent_thoughts"][0]["chain_id"] == "chain-1"
|
||||
assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp())
|
||||
|
||||
assert item["message_files"][0]["id"] == "file-dict"
|
||||
assert item["message_files"][1]["id"] == "file-obj"
|
||||
Loading…
Reference in New Issue