From e78558bc06ce63c6ddf7de038bdddfd1a6a8374d Mon Sep 17 00:00:00 2001 From: NVIDIAN Date: Tue, 14 Apr 2026 11:12:40 -0700 Subject: [PATCH] refactor(api): migrate dataset hit-testing response model to BaseModel (#35192) Co-authored-by: ai-hpc --- .../console/datasets/hit_testing.py | 129 +++++++++++++----- .../console/datasets/test_hit_testing.py | 51 +++++++ 2 files changed, 144 insertions(+), 36 deletions(-) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index e62be13c2f..36a7a4bb0e 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,13 +1,13 @@ -from flask_restx import Resource, fields +from __future__ import annotations -from controllers.common.schema import register_schema_model -from fields.hit_testing_fields import ( - child_chunk_fields, - document_fields, - files_fields, - hit_testing_record_fields, - segment_fields, -) +from datetime import datetime +from typing import Any + +from flask_restx import Resource +from pydantic import Field, field_validator + +from controllers.common.schema import register_schema_models +from fields.base import ResponseModel from libs.login import login_required from .. import console_ns @@ -18,39 +18,92 @@ from ..wraps import ( setup_required, ) -register_schema_model(console_ns, HitTestingPayload) + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -def _get_or_create_model(model_name: str, field_def): - """Get or create a flask_restx model to avoid dict type issues in Swagger.""" - existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing +class HitTestingDocument(ResponseModel): + id: str | None = None + data_source_type: str | None = None + name: str | None = None + doc_type: str | None = None + doc_metadata: Any | None = None -# Register models for flask_restx to avoid dict type issues in Swagger -document_model = _get_or_create_model("HitTestingDocument", document_fields) +class HitTestingSegment(ResponseModel): + id: str | None = None + position: int | None = None + document_id: str | None = None + content: str | None = None + sign_content: str | None = None + answer: str | None = None + word_count: int | None = None + tokens: int | None = None + keywords: list[str] = Field(default_factory=list) + index_node_id: str | None = None + index_node_hash: str | None = None + hit_count: int | None = None + enabled: bool | None = None + disabled_at: int | None = None + disabled_by: str | None = None + status: str | None = None + created_by: str | None = None + created_at: int | None = None + indexing_at: int | None = None + completed_at: int | None = None + error: str | None = None + stopped_at: int | None = None + document: HitTestingDocument | None = None -segment_fields_copy = segment_fields.copy() -segment_fields_copy["document"] = fields.Nested(document_model) -segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy) + @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields) -files_model = _get_or_create_model("HitTestingFile", files_fields) -hit_testing_record_fields_copy = hit_testing_record_fields.copy() -hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model) -hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model)) -hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model)) -hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy) +class HitTestingChildChunk(ResponseModel): + id: str | None = None + content: str | None = None + position: int | None = None + score: float | None = None -# Response model for hit testing API -hit_testing_response_fields = { - "query": fields.String, - "records": fields.List(fields.Nested(hit_testing_record_model)), -} -hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields) + +class HitTestingFile(ResponseModel): + id: str | None = None + name: str | None = None + size: int | None = None + extension: str | None = None + mime_type: str | None = None + source_url: str | None = None + + +class HitTestingRecord(ResponseModel): + segment: HitTestingSegment | None = None + child_chunks: list[HitTestingChildChunk] = Field(default_factory=list) + score: float | None = None + tsne_position: Any | None = None + files: list[HitTestingFile] = Field(default_factory=list) + summary: str | None = None + + +class HitTestingResponse(ResponseModel): + query: str + records: list[HitTestingRecord] = Field(default_factory=list) + + +register_schema_models( + console_ns, + HitTestingPayload, + HitTestingDocument, + HitTestingSegment, + HitTestingChildChunk, + HitTestingFile, + HitTestingRecord, + HitTestingResponse, +) @console_ns.route("/datasets//hit-testing") @@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.expect(console_ns.models[HitTestingPayload.__name__]) - @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model) + @console_ns.response( + 200, + "Hit testing completed successfully", + model=console_ns.models[HitTestingResponse.__name__], + ) @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @setup_required @@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): args = payload.model_dump(exclude_none=True) self.hit_testing_args_check(args) - return self.perform_hit_testing(dataset, args) + return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 726c0a5cf3..09ed2aaf69 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -99,6 +99,57 @@ class TestHitTestingApi: assert "records" in result assert result["records"] == [] + def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + } + records = [ + { + "segment": None, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } + ] + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": payload["query"], "records": records}, + ), + ): + result = method(api, dataset_id) + + assert result["query"] == payload["query"] + assert result["records"] == records + def test_hit_testing_dataset_not_found(self, app, dataset_id): api = HitTestingApi() method = unwrap(api.post)