refactor(api): migrate dataset document response schemas to BaseModel (#35298)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-16 22:02:04 -07:00 committed by GitHub
parent dc3f992e6e
commit af21dc7df8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 130 additions and 45 deletions

View File

@ -3,18 +3,19 @@ import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import ExitStack from contextlib import ExitStack
from datetime import datetime
from typing import Any, Literal, cast from typing import Any, Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request, send_file from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with from flask_restx import Resource, marshal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.schema import get_or_create_model, register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from core.errors.error import ( from core.errors.error import (
LLMBadRequestError, LLMBadRequestError,
@ -29,11 +30,9 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db from extensions.ext_database import db
from fields.dataset_fields import dataset_fields from fields.base import ResponseModel
from fields.document_fields import ( from fields.document_fields import (
dataset_and_document_fields,
document_fields, document_fields,
document_metadata_fields,
document_status_fields, document_status_fields,
document_with_segments_fields, document_with_segments_fields,
) )
@ -72,27 +71,100 @@ from ..wraps import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Register models for flask_restx to avoid dict type issues in Swagger def _to_timestamp(value: datetime | int | None) -> int | None:
dataset_model = get_or_create_model("Dataset", dataset_fields) if isinstance(value, datetime):
return int(value.timestamp())
return value
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy() def _normalize_enum(value: Any) -> Any:
document_fields_copy["doc_metadata"] = fields.List( if isinstance(value, str) or value is None:
fields.Nested(document_metadata_model), attribute="doc_metadata_details" return value
) return getattr(value, "value", value)
document_model = get_or_create_model("Document", document_fields_copy)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy() class DatasetResponse(ResponseModel):
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) id: str
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) name: str
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) description: str | None = None
permission: str | None = None
data_source_type: str | None = None
indexing_technique: str | None = None
created_by: str | None = None
created_at: int | None = None
@field_validator("data_source_type", "indexing_technique", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DocumentMetadataResponse(ResponseModel):
id: str
name: str
type: str
value: str | None = None
class DocumentResponse(ResponseModel):
id: str
position: int | None = None
data_source_type: str | None = None
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
data_source_detail_dict: Any = None
dataset_process_rule_id: str | None = None
name: str
created_from: str | None = None
created_by: str | None = None
created_at: int | None = None
tokens: int | None = None
indexing_status: str | None = None
error: str | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
archived: bool | None = None
display_status: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
summary_index_status: str | None = None
need_summary: bool | None = None
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("doc_metadata", mode="before")
@classmethod
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
if value is None:
return []
return value
@field_validator("created_at", "disabled_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DocumentWithSegmentsResponse(DocumentResponse):
process_rule_dict: Any = None
completed_segments: int | None = None
total_segments: int | None = None
class DatasetAndDocumentResponse(ResponseModel):
dataset: DatasetResponse
documents: list[DocumentResponse]
batch: str
class DocumentRetryPayload(BaseModel): class DocumentRetryPayload(BaseModel):
@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str] document_list: list[str]
class DocumentMetadataUpdatePayload(BaseModel):
doc_type: str | None = None
doc_metadata: Any = None
class DocumentDatasetListParam(BaseModel): class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.") page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.") limit: int = Field(20, title="Limit", description="Page size.")
@ -124,7 +201,13 @@ register_schema_models(
DocumentRetryPayload, DocumentRetryPayload,
DocumentRenamePayload, DocumentRenamePayload,
GenerateSummaryPayload, GenerateSummaryPayload,
DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload, DocumentBatchDownloadZipPayload,
DatasetResponse,
DocumentMetadataResponse,
DocumentResponse,
DocumentWithSegmentsResponse,
DatasetAndDocumentResponse,
) )
@ -357,10 +440,10 @@ class DatasetDocumentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -398,7 +481,9 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
return {"dataset": dataset, "documents": documents, "batch": batch} return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@setup_required @setup_required
@login_required @login_required
@ -426,12 +511,13 @@ class DatasetInitApi(Resource):
@console_ns.doc("init_dataset") @console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents") @console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) @console_ns.response(
201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
)
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
@ -479,9 +565,9 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
response = {"dataset": dataset, "documents": documents, "batch": batch} return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
return response ).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource):
@console_ns.doc("update_document_metadata") @console_ns.doc("update_document_metadata")
@console_ns.doc(description="Update document metadata") @console_ns.doc(description="Update document metadata")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
console_ns.model(
"UpdateDocumentMetadataRequest",
{
"doc_type": fields.String(description="Document type"),
"doc_metadata": fields.Raw(description="Document metadata"),
},
)
)
@console_ns.response(200, "Document metadata updated successfully") @console_ns.response(200, "Document metadata updated successfully")
@console_ns.response(404, "Document not found") @console_ns.response(404, "Document not found")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@ -1009,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource):
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
req_data = request.get_json() req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
doc_type = req_data.get("doc_type") doc_type = req_data.doc_type
doc_metadata = req_data.get("doc_metadata") doc_metadata = req_data.doc_metadata
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(document_model) @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@ -1212,7 +1290,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError: except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.") raise DocumentIndexingError("Cannot delete document during indexing.")
return document return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")

View File

@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -215,17 +216,23 @@ class TestDatasetDocumentListApi:
method = unwrap(api.post) method = unwrap(api.post)
payload = {"indexing_technique": "economy"} payload = {"indexing_technique": "economy"}
created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy")
created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None)
with ( with (
app.test_request_context("/", json=payload), app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload), patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_document.DatasetService.get_dataset",
return_value=created_dataset,
),
patch( patch(
"controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate",
return_value=None, return_value=None,
), ),
patch( patch(
"controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id",
return_value=([MagicMock()], "batch-1"), return_value=([created_document], "batch-1"),
), ),
): ):
response = method(api, "ds-1") response = method(api, "ds-1")