diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index bbacf076c7..21c289ca82 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services +from constants import UUID_NIL from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( @@ -87,7 +88,7 @@ def _validate_parent_message_request( conversation_id: str | None, parent_message_id: str | None, ) -> None: - if not parent_message_id: + if not parent_message_id or parent_message_id == UUID_NIL: return if not conversation_id: diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 452d0efb67..88255bdb53 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -126,6 +126,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model=app_model, conversation_id=conversation_id, user=user ) + self._validate_parent_message_for_service_api( + app_model=app_model, + user=user, + conversation=conversation, + parent_message_id=args.get("parent_message_id"), + invoke_from=invoke_from, + ) + # parse files # TODO(QuantumGhost): Move file parsing logic to the API controller layer # for better separation of concerns. diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 8af412add4..38b261a17a 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -104,6 +104,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation = ConversationService.get_conversation( app_model=app_model, conversation_id=conversation_id, user=user ) + + self._validate_parent_message_for_service_api( + app_model=app_model, + user=user, + conversation=conversation, + parent_message_id=args.get("parent_message_id"), + invoke_from=invoke_from, + ) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index a94fe1731f..fb24e87826 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -96,6 +96,14 @@ class ChatAppGenerator(MessageBasedAppGenerator): conversation = ConversationService.get_conversation( app_model=app_model, conversation_id=conversation_id, user=user ) + + self._validate_parent_message_for_service_api( + app_model=app_model, + user=user, + conversation=conversation, + parent_message_id=args.get("parent_message_id"), + invoke_from=invoke_from, + ) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 3a00c08152..d58ff3a67c 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -35,6 +35,7 @@ from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Me from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError +from services.message_service import MessageService logger = logging.getLogger(__name__) @@ -91,6 +92,28 @@ class MessageBasedAppGenerator(BaseAppGenerator): return UUID_NIL return parent_message_id + def _validate_parent_message_for_service_api( + self, + *, + app_model: App, + user: Union[Account, EndUser], + conversation: Conversation | None, + parent_message_id: str | None, + invoke_from: InvokeFrom, + ) -> None: + if invoke_from != InvokeFrom.SERVICE_API: + return + + if not parent_message_id or parent_message_id == UUID_NIL: + return + + if not conversation: + raise ConversationNotExistsError("Conversation not exists") + + parent_message = MessageService.get_message(app_model=app_model, user=user, message_id=parent_message_id) + if parent_message.conversation_id != conversation.id: + raise MessageNotExistsError("Message not exists") + def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig: if conversation: stmt = select(AppModelConfig).where( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_chat_parent_message_validation.py b/api/tests/unit_tests/controllers/service_api/app/test_chat_parent_message_validation.py index 808d0c6a18..dde0bb08ed 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_chat_parent_message_validation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_chat_parent_message_validation.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from werkzeug.exceptions import BadRequest, NotFound +from constants import UUID_NIL from controllers.service_api.app.completion import _validate_parent_message_request from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError @@ -25,6 +26,25 @@ def test_validate_parent_message_skips_when_missing(): get_message.assert_not_called() +def test_validate_parent_message_skips_uuid_nil(): + app_model = object() + end_user = object() + + with ( + patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation, + patch("controllers.service_api.app.completion.MessageService.get_message") as get_message, + ): + _validate_parent_message_request( + app_model=app_model, + end_user=end_user, + conversation_id=None, + parent_message_id=UUID_NIL, + ) + + get_conversation.assert_not_called() + get_message.assert_not_called() + + def test_validate_parent_message_requires_conversation_id(): app_model = object() end_user = object() diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator_parent_message.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator_parent_message.py new file mode 100644 index 0000000000..3cb7a0f62a --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator_parent_message.py @@ -0,0 +1,90 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from constants import UUID_NIL +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError + + +def test_validate_parent_message_service_api_skips_missing(): + generator = MessageBasedAppGenerator() + + with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message: + generator._validate_parent_message_for_service_api( + app_model=object(), + user=object(), + conversation=None, + parent_message_id=None, + invoke_from=InvokeFrom.SERVICE_API, + ) + + get_message.assert_not_called() + + +def test_validate_parent_message_service_api_skips_uuid_nil(): + generator = MessageBasedAppGenerator() + + with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message: + generator._validate_parent_message_for_service_api( + app_model=object(), + user=object(), + conversation=None, + parent_message_id=UUID_NIL, + invoke_from=InvokeFrom.SERVICE_API, + ) + + get_message.assert_not_called() + + +def test_validate_parent_message_service_api_requires_conversation(): + generator = MessageBasedAppGenerator() + + with pytest.raises(ConversationNotExistsError): + generator._validate_parent_message_for_service_api( + app_model=object(), + user=object(), + conversation=None, + parent_message_id="parent-id", + invoke_from=InvokeFrom.SERVICE_API, + ) + + +def test_validate_parent_message_service_api_mismatch_conversation(): + generator = MessageBasedAppGenerator() + conversation = SimpleNamespace(id="conversation-id") + parent_message = SimpleNamespace(conversation_id="different-id") + + with patch( + "core.app.apps.message_based_app_generator.MessageService.get_message", + return_value=parent_message, + ): + with pytest.raises(MessageNotExistsError): + generator._validate_parent_message_for_service_api( + app_model=object(), + user=object(), + conversation=conversation, + parent_message_id="parent-id", + invoke_from=InvokeFrom.SERVICE_API, + ) + + +def test_validate_parent_message_service_api_matches_conversation(): + generator = MessageBasedAppGenerator() + conversation = SimpleNamespace(id="conversation-id") + parent_message = SimpleNamespace(conversation_id="conversation-id") + + with patch( + "core.app.apps.message_based_app_generator.MessageService.get_message", + return_value=parent_message, + ): + generator._validate_parent_message_for_service_api( + app_model=object(), + user=object(), + conversation=conversation, + parent_message_id="parent-id", + invoke_from=InvokeFrom.SERVICE_API, + )