diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index a037fe9254..b7fb01c6fe 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,7 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -52,11 +52,23 @@ class ChatRequestPayload(BaseModel): query: str files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None - conversation_id: UUID | None = None + conversation_id: str | None = Field(default=None, description="Conversation UUID") retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") + @field_validator("conversation_id", mode="before") + @classmethod + def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: + """Allow missing or blank conversation IDs; enforce UUID format when provided.""" + if not value: + return None + + try: + return helper.uuid_value(value) + except ValueError as exc: + raise ValueError("conversation_id must be a valid UUID") from exc + register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload) diff --git a/api/libs/helper.py b/api/libs/helper.py index 0506e0ed5f..a278ace6ad 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -107,7 +107,7 @@ def email(email): EmailStr = Annotated[str, AfterValidator(email)] -def uuid_value(value): +def uuid_value(value: Any) -> str: if value == "": return str(value) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py b/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py new file mode 100644 index 0000000000..1fb7e7009d --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py @@ -0,0 +1,25 @@ +import uuid + +import pytest +from pydantic import ValidationError + +from controllers.service_api.app.completion import ChatRequestPayload + + +def test_chat_request_payload_accepts_blank_conversation_id(): + payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": ""}) + + assert payload.conversation_id is None + + +def test_chat_request_payload_validates_uuid(): + conversation_id = str(uuid.uuid4()) + + payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": conversation_id}) + + assert payload.conversation_id == conversation_id + + +def test_chat_request_payload_rejects_invalid_uuid(): + with pytest.raises(ValidationError): + ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})