mirror of https://github.com/langgenius/dify.git
feat: implement parent message validation for service API
Added a new validation method to check parent message IDs in the MessageBasedAppGenerator class, ensuring proper handling of UUID_NIL and conversation existence. Updated related app generators and added unit tests for comprehensive coverage.
This commit is contained in:
parent
df09acb74b
commit
6c9304d7ef
|
|
@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from constants import UUID_NIL
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import (
|
from controllers.service_api.app.error import (
|
||||||
|
|
@ -87,7 +88,7 @@ def _validate_parent_message_request(
|
||||||
conversation_id: str | None,
|
conversation_id: str | None,
|
||||||
parent_message_id: str | None,
|
parent_message_id: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not parent_message_id:
|
if not parent_message_id or parent_message_id == UUID_NIL:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
app_model=app_model, conversation_id=conversation_id, user=user
|
app_model=app_model, conversation_id=conversation_id, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._validate_parent_message_for_service_api(
|
||||||
|
app_model=app_model,
|
||||||
|
user=user,
|
||||||
|
conversation=conversation,
|
||||||
|
parent_message_id=args.get("parent_message_id"),
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
)
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||||
# for better separation of concerns.
|
# for better separation of concerns.
|
||||||
|
|
|
||||||
|
|
@ -104,6 +104,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
conversation = ConversationService.get_conversation(
|
conversation = ConversationService.get_conversation(
|
||||||
app_model=app_model, conversation_id=conversation_id, user=user
|
app_model=app_model, conversation_id=conversation_id, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._validate_parent_message_for_service_api(
|
||||||
|
app_model=app_model,
|
||||||
|
user=user,
|
||||||
|
conversation=conversation,
|
||||||
|
parent_message_id=args.get("parent_message_id"),
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
)
|
||||||
# get app model config
|
# get app model config
|
||||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,14 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
conversation = ConversationService.get_conversation(
|
conversation = ConversationService.get_conversation(
|
||||||
app_model=app_model, conversation_id=conversation_id, user=user
|
app_model=app_model, conversation_id=conversation_id, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._validate_parent_message_for_service_api(
|
||||||
|
app_model=app_model,
|
||||||
|
user=user,
|
||||||
|
conversation=conversation,
|
||||||
|
parent_message_id=args.get("parent_message_id"),
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
)
|
||||||
# get app model config
|
# get app model config
|
||||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Me
|
||||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
|
from services.message_service import MessageService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -91,6 +92,28 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||||
return UUID_NIL
|
return UUID_NIL
|
||||||
return parent_message_id
|
return parent_message_id
|
||||||
|
|
||||||
|
def _validate_parent_message_for_service_api(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
app_model: App,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
conversation: Conversation | None,
|
||||||
|
parent_message_id: str | None,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
) -> None:
|
||||||
|
if invoke_from != InvokeFrom.SERVICE_API:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not parent_message_id or parent_message_id == UUID_NIL:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
raise ConversationNotExistsError("Conversation not exists")
|
||||||
|
|
||||||
|
parent_message = MessageService.get_message(app_model=app_model, user=user, message_id=parent_message_id)
|
||||||
|
if parent_message.conversation_id != conversation.id:
|
||||||
|
raise MessageNotExistsError("Message not exists")
|
||||||
|
|
||||||
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
|
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
|
||||||
if conversation:
|
if conversation:
|
||||||
stmt = select(AppModelConfig).where(
|
stmt = select(AppModelConfig).where(
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
from controllers.service_api.app.completion import _validate_parent_message_request
|
from controllers.service_api.app.completion import _validate_parent_message_request
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
|
|
@ -25,6 +26,25 @@ def test_validate_parent_message_skips_when_missing():
|
||||||
get_message.assert_not_called()
|
get_message.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_parent_message_skips_uuid_nil():
|
||||||
|
app_model = object()
|
||||||
|
end_user = object()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
|
||||||
|
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
|
||||||
|
):
|
||||||
|
_validate_parent_message_request(
|
||||||
|
app_model=app_model,
|
||||||
|
end_user=end_user,
|
||||||
|
conversation_id=None,
|
||||||
|
parent_message_id=UUID_NIL,
|
||||||
|
)
|
||||||
|
|
||||||
|
get_conversation.assert_not_called()
|
||||||
|
get_message.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_validate_parent_message_requires_conversation_id():
|
def test_validate_parent_message_requires_conversation_id():
|
||||||
app_model = object()
|
app_model = object()
|
||||||
end_user = object()
|
end_user = object()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,90 @@
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
|
from services.errors.message import MessageNotExistsError
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_parent_message_service_api_skips_missing():
|
||||||
|
generator = MessageBasedAppGenerator()
|
||||||
|
|
||||||
|
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
|
||||||
|
generator._validate_parent_message_for_service_api(
|
||||||
|
app_model=object(),
|
||||||
|
user=object(),
|
||||||
|
conversation=None,
|
||||||
|
parent_message_id=None,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
)
|
||||||
|
|
||||||
|
get_message.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_parent_message_service_api_skips_uuid_nil():
|
||||||
|
generator = MessageBasedAppGenerator()
|
||||||
|
|
||||||
|
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
|
||||||
|
generator._validate_parent_message_for_service_api(
|
||||||
|
app_model=object(),
|
||||||
|
user=object(),
|
||||||
|
conversation=None,
|
||||||
|
parent_message_id=UUID_NIL,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
)
|
||||||
|
|
||||||
|
get_message.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_parent_message_service_api_requires_conversation():
|
||||||
|
generator = MessageBasedAppGenerator()
|
||||||
|
|
||||||
|
with pytest.raises(ConversationNotExistsError):
|
||||||
|
generator._validate_parent_message_for_service_api(
|
||||||
|
app_model=object(),
|
||||||
|
user=object(),
|
||||||
|
conversation=None,
|
||||||
|
parent_message_id="parent-id",
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_parent_message_service_api_mismatch_conversation():
|
||||||
|
generator = MessageBasedAppGenerator()
|
||||||
|
conversation = SimpleNamespace(id="conversation-id")
|
||||||
|
parent_message = SimpleNamespace(conversation_id="different-id")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"core.app.apps.message_based_app_generator.MessageService.get_message",
|
||||||
|
return_value=parent_message,
|
||||||
|
):
|
||||||
|
with pytest.raises(MessageNotExistsError):
|
||||||
|
generator._validate_parent_message_for_service_api(
|
||||||
|
app_model=object(),
|
||||||
|
user=object(),
|
||||||
|
conversation=conversation,
|
||||||
|
parent_message_id="parent-id",
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_parent_message_service_api_matches_conversation():
|
||||||
|
generator = MessageBasedAppGenerator()
|
||||||
|
conversation = SimpleNamespace(id="conversation-id")
|
||||||
|
parent_message = SimpleNamespace(conversation_id="conversation-id")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"core.app.apps.message_based_app_generator.MessageService.get_message",
|
||||||
|
return_value=parent_message,
|
||||||
|
):
|
||||||
|
generator._validate_parent_message_for_service_api(
|
||||||
|
app_model=object(),
|
||||||
|
user=object(),
|
||||||
|
conversation=conversation,
|
||||||
|
parent_message_id="parent-id",
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue