mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 11:10:19 +08:00
fix: Add the missing validation of doc_form in the service API. (#32892)
This commit is contained in:
parent
c8688ec371
commit
2068640a4b
@ -119,6 +119,14 @@ def _validate_indexing_technique(value: str | None) -> str | None:
|
||||
return value
|
||||
|
||||
|
||||
def _validate_doc_form(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
if value not in Dataset.DOC_FORM_LIST:
|
||||
raise ValueError("Invalid doc_form.")
|
||||
return value
|
||||
|
||||
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field("", max_length=400)
|
||||
@ -179,6 +187,14 @@ class IndexingEstimatePayload(BaseModel):
|
||||
raise ValueError("indexing_technique is required.")
|
||||
return result
|
||||
|
||||
@field_validator("doc_form")
|
||||
@classmethod
|
||||
def validate_doc_form(cls, value: str) -> str:
|
||||
result = _validate_doc_form(value)
|
||||
if result is None:
|
||||
return "text_model"
|
||||
return result
|
||||
|
||||
|
||||
class ConsoleDatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
|
||||
@ -4,7 +4,7 @@ from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -60,6 +60,13 @@ class DocumentTextCreatePayload(BaseModel):
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
|
||||
@field_validator("doc_form")
|
||||
@classmethod
|
||||
def validate_doc_form(cls, value: str) -> str:
|
||||
if value not in Dataset.DOC_FORM_LIST:
|
||||
raise ValueError("Invalid doc_form.")
|
||||
return value
|
||||
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@ -72,6 +79,13 @@ class DocumentTextUpdate(BaseModel):
|
||||
doc_language: str = "English"
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
|
||||
@field_validator("doc_form")
|
||||
@classmethod
|
||||
def validate_doc_form(cls, value: str) -> str:
|
||||
if value not in Dataset.DOC_FORM_LIST:
|
||||
raise ValueError("Invalid doc_form.")
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_text_and_name(self) -> Self:
|
||||
if self.text is not None and self.name is None:
|
||||
|
||||
@ -19,6 +19,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.signature import sign_upload_file
|
||||
@ -51,6 +52,7 @@ class Dataset(Base):
|
||||
|
||||
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
|
||||
PROVIDER_LIST = ["vendor", "external", None]
|
||||
DOC_FORM_LIST = [member.value for member in IndexStructureType]
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
@ -127,6 +128,18 @@ class KnowledgeConfig(BaseModel):
|
||||
name: str | None = None
|
||||
is_multimodal: bool = False
|
||||
|
||||
@field_validator("doc_form")
|
||||
@classmethod
|
||||
def validate_doc_form(cls, value: str) -> str:
|
||||
valid_forms = [
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARENT_CHILD_INDEX,
|
||||
]
|
||||
if value not in valid_forms:
|
||||
raise ValueError("Invalid doc_form.")
|
||||
return value
|
||||
|
||||
|
||||
class SegmentCreateArgs(BaseModel):
|
||||
content: str | None = None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user