fix: Add doc_form validation to the document-related service APIs.

This commit is contained in:
FFXN 2026-03-03 14:31:51 +08:00
parent 53c62fde33
commit 269bf883c2
4 changed files with 38 additions and 2 deletions

View File

@ -119,6 +119,14 @@ def _validate_indexing_technique(value: str | None) -> str | None:
return value
def _validate_doc_form(value: str | None) -> str | None:
    """Check an optional ``doc_form`` value against ``Dataset.DOC_FORM_LIST``.

    Args:
        value: The candidate ``doc_form``; ``None`` is accepted as-is.

    Returns:
        ``value`` unchanged when it is ``None`` or a member of
        ``Dataset.DOC_FORM_LIST``.

    Raises:
        ValueError: If ``value`` is not ``None`` and is not a member of
            ``Dataset.DOC_FORM_LIST``.
    """
    # Short-circuit keeps the None case from ever consulting the list.
    if value is not None and value not in Dataset.DOC_FORM_LIST:
        raise ValueError("Invalid doc_form.")
    return value
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field("", max_length=400)

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -60,6 +60,13 @@ class DocumentTextCreatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
@field_validator("doc_form")
@classmethod
def validate_doc_form(cls, value: str) -> str:
    """Pydantic field validator for ``doc_form``.

    Args:
        value: The ``doc_form`` value supplied for the model.

    Returns:
        ``value`` unchanged when it is a member of ``Dataset.DOC_FORM_LIST``.

    Raises:
        ValueError: If ``value`` is not a member of ``Dataset.DOC_FORM_LIST``.
    """
    if value in Dataset.DOC_FORM_LIST:
        return value
    raise ValueError("Invalid doc_form.")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -72,6 +79,13 @@ class DocumentTextUpdate(BaseModel):
doc_language: str = "English"
retrieval_model: RetrievalModel | None = None
@field_validator("doc_form")
@classmethod
def validate_doc_form(cls, value: str) -> str:
    """Pydantic field validator for ``doc_form``.

    Args:
        value: The ``doc_form`` value supplied for the model.

    Returns:
        ``value`` unchanged when it is a member of ``Dataset.DOC_FORM_LIST``.

    Raises:
        ValueError: If ``value`` is not a member of ``Dataset.DOC_FORM_LIST``.
    """
    is_known_form = value in Dataset.DOC_FORM_LIST
    if not is_known_form:
        raise ValueError("Invalid doc_form.")
    return value
@model_validator(mode="after")
def check_text_and_name(self) -> Self:
if self.text is not None and self.name is None:

View File

@ -51,6 +51,7 @@ class Dataset(Base):
# Enumerated values for the indexing technique; note that None is itself a list member.
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
# Enumerated values for the provider; note that None is itself a list member.
PROVIDER_LIST = ["vendor", "external", None]
# Enumerated doc_form values. Unlike the two lists above, None is not a member.
# NOTE(review): presumably mirrors the IndexStructureType values — confirm they stay in sync.
DOC_FORM_LIST = ["text_model", "qa_model", "hierarchical_model"]
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID)

View File

@ -1,8 +1,9 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -127,6 +128,18 @@ class KnowledgeConfig(BaseModel):
name: str | None = None
is_multimodal: bool = False
@field_validator("doc_form")
@classmethod
def validate_doc_form(cls, value: str) -> str:
    """Pydantic field validator for ``doc_form``.

    Accepts only the paragraph, QA and parent-child members of
    ``IndexStructureType``.

    Args:
        value: The ``doc_form`` value supplied for the model.

    Returns:
        ``value`` unchanged when it equals one of the accepted members.

    Raises:
        ValueError: If ``value`` matches none of the accepted members.
    """
    accepted = (
        IndexStructureType.PARAGRAPH_INDEX,
        IndexStructureType.QA_INDEX,
        IndexStructureType.PARENT_CHILD_INDEX,
    )
    if value in accepted:
        return value
    raise ValueError("Invalid doc_form.")
class SegmentCreateArgs(BaseModel):
content: str | None = None