mirror of https://github.com/langgenius/dify.git
feat: chat messages api support parent message id
This commit is contained in:
parent
fb14644a79
commit
df09acb74b
|
|
@ -4,7 +4,7 @@ from uuid import UUID
|
|||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -33,8 +33,11 @@ from libs import helper
|
|||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -53,14 +56,18 @@ class ChatRequestPayload(BaseModel):
|
|||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: str | None = Field(default=None, description="Conversation UUID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message UUID")
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@field_validator("conversation_id", "parent_message_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
|
||||
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
|
||||
def normalize_uuid_fields(cls, value: str | UUID | None, info: ValidationInfo) -> str | None:
|
||||
"""Allow missing or blank UUID fields; enforce UUID format when provided."""
|
||||
if isinstance(value, UUID):
|
||||
return str(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
|
||||
|
|
@ -70,7 +77,36 @@ class ChatRequestPayload(BaseModel):
|
|||
try:
|
||||
return helper.uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
raise ValueError(f"{info.field_name} must be a valid UUID") from exc
|
||||
|
||||
|
||||
def _validate_parent_message_request(
|
||||
*,
|
||||
app_model: App,
|
||||
end_user: EndUser,
|
||||
conversation_id: str | None,
|
||||
parent_message_id: str | None,
|
||||
) -> None:
|
||||
if not parent_message_id:
|
||||
return
|
||||
|
||||
if not conversation_id:
|
||||
raise BadRequest("conversation_id is required when parent_message_id is provided.")
|
||||
|
||||
try:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=end_user
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
try:
|
||||
parent_message = MessageService.get_message(app_model=app_model, user=end_user, message_id=parent_message_id)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
if parent_message.conversation_id != conversation.id:
|
||||
raise BadRequest("parent_message_id does not belong to the conversation.")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
|
||||
|
|
@ -205,6 +241,13 @@ class ChatApi(Resource):
|
|||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id=args.get("conversation_id"),
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from sqlalchemy.orm import Session, sessionmaker
|
|||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
|
|
@ -168,7 +167,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from flask import Flask, current_app
|
|||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
|
|
@ -163,7 +162,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from flask import Flask, copy_current_request_context, current_app
|
|||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
|
|
@ -156,7 +155,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
|
||||
user_id=user.id,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Union, cast
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
|
|
@ -84,6 +85,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
def _resolve_parent_message_id(self, args: Mapping[str, Any], invoke_from: InvokeFrom) -> str | None:
|
||||
parent_message_id = args.get("parent_message_id")
|
||||
if invoke_from == InvokeFrom.SERVICE_API and not parent_message_id:
|
||||
return UUID_NIL
|
||||
return parent_message_id
|
||||
|
||||
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
|
||||
if conversation:
|
||||
stmt = select(AppModelConfig).where(
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@ from collections.abc import Mapping, Sequence
|
|||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file import File, FileUploadConfig
|
||||
|
|
@ -158,20 +157,12 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
|
|||
parent_message_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
|
||||
"For service API, we need to ensure its forward compatibility, "
|
||||
"so passing in the parent_message_id as request arg is not supported for now. "
|
||||
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
|
||||
"Starting from v0.9.0, parent_message_id is used to support message regeneration "
|
||||
"and branching in chat APIs."
|
||||
"For service API, when it is omitted, the system treats it as UUID_NIL to preserve legacy linear history."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("parent_message_id")
|
||||
@classmethod
|
||||
def validate_parent_message_id(cls, v, info: ValidationInfo):
|
||||
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
|
||||
raise ValueError("parent_message_id should be UUID_NIL for service API")
|
||||
return v
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.service_api.app.completion import _validate_parent_message_request
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def test_validate_parent_message_skips_when_missing():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
|
||||
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
|
||||
):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id=None
|
||||
)
|
||||
|
||||
get_conversation.assert_not_called()
|
||||
get_message.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_parent_message_requires_conversation_id():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id="parent-id"
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_missing_conversation_raises_not_found():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
|
||||
with patch(
|
||||
"controllers.service_api.app.completion.ConversationService.get_conversation",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_missing_message_raises_not_found():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
|
||||
patch(
|
||||
"controllers.service_api.app.completion.MessageService.get_message",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_mismatch_conversation_raises_bad_request():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
message = SimpleNamespace(conversation_id="different-id")
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
|
||||
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_matches_conversation():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
message = SimpleNamespace(conversation_id="conversation-id")
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
|
||||
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
|
||||
):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
|
|
@ -23,3 +23,24 @@ def test_chat_request_payload_validates_uuid():
|
|||
def test_chat_request_payload_rejects_invalid_uuid():
|
||||
with pytest.raises(ValidationError):
|
||||
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})
|
||||
|
||||
|
||||
def test_chat_request_payload_accepts_blank_parent_message_id():
|
||||
payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": ""})
|
||||
|
||||
assert payload.parent_message_id is None
|
||||
|
||||
|
||||
def test_chat_request_payload_validates_parent_message_id_uuid():
|
||||
parent_message_id = str(uuid.uuid4())
|
||||
|
||||
payload = ChatRequestPayload.model_validate(
|
||||
{"inputs": {}, "query": "hello", "parent_message_id": parent_message_id}
|
||||
)
|
||||
|
||||
assert payload.parent_message_id == parent_message_id
|
||||
|
||||
|
||||
def test_chat_request_payload_rejects_invalid_parent_message_id():
|
||||
with pytest.raises(ValidationError):
|
||||
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": "invalid"})
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
|
||||
- `type` (string) Supported type:
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
会話ID、以前のチャット記録に基づいて会話を続けるには、以前のメッセージのconversation_idを渡す必要があります。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
|
||||
- `type` (string) サポートされるタイプ:
|
||||
|
|
|
|||
|
|
@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
(选填)会话 ID,需要基于之前的聊天记录继续对话,必须传之前消息的 conversation_id。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
|
||||
- `type` (string) 支持类型:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,9 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
|
||||
- `type` (string) Supported type:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
会話ID、以前のチャット記録に基づいて会話を続けるには、前のメッセージのconversation_idを渡す必要があります。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
|
||||
- `type` (string) サポートされるタイプ:
|
||||
|
|
|
|||
|
|
@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
(选填)会话 ID,需要基于之前的聊天记录继续对话,必须传之前消息的 conversation_id。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
|
||||
- `type` (string) 支持类型:
|
||||
|
|
|
|||
Loading…
Reference in New Issue