From dbceb3067e5c1ce177bf60ec8617b53e2d831ac3 Mon Sep 17 00:00:00 2001 From: NVIDIAN Date: Wed, 15 Apr 2026 02:57:27 -0700 Subject: [PATCH] refactor(api): migrate console tag responses from marshal_with to BaseModel (#35208) Co-authored-by: ai-hpc --- api/controllers/console/tag/tags.py | 56 +++++++++++++------ .../controllers/console/tag/test_tags.py | 26 ++++++++- 2 files changed, 62 insertions(+), 20 deletions(-) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 39b84d3869..614bf03ea5 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,13 +1,14 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.enums import TagType from services.tag_service import ( @@ -18,17 +19,6 @@ from services.tag_service import ( UpdateTagPayload, ) -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} - - -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) - class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) @@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel): keyword: str | None = Field(None, description="Search keyword") +class TagResponse(ResponseModel): + id: str + name: str + type: str | None = None + binding_count: str | None = None + + @field_validator("type", mode="before") + @classmethod + def normalize_type(cls, value: TagType | str | None) -> str | None: + if value is None: + return None + if isinstance(value, TagType): + return value.value + return value + + @field_validator("binding_count", mode="before") + @classmethod + def normalize_binding_count(cls, value: int | str | None) -> str | None: + if value is None: + return None + return str(value) + + register_schema_models( console_ns, TagBasePayload, TagBindingPayload, TagBindingRemovePayload, TagListQueryParam, + TagResponse, ) @@ -69,14 +83,18 @@ class TagListApi(Resource): @console_ns.doc( params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} ) - @marshal_with(dataset_tag_fields) + @console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])}) def get(self): _, current_tenant_id = current_account_with_tenant() raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) - return tags, 200 + serialized_tags = [ + TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags + ] + + return serialized_tags, 200 @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @@ -91,7 +109,9 @@ class TagListApi(Resource): payload = TagBasePayload.model_validate(console_ns.payload or {}) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type)) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index e89b89c8b1..2be5a21f28 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -1,9 +1,11 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, PropertyMock, patch import pytest from flask import Flask from werkzeug.exceptions import Forbidden +import controllers.console.tag.tags as module from controllers.console import console_ns from controllers.console.tag.tags import ( TagBindingCreateApi, @@ -83,13 +85,20 @@ class TestTagListApi: ), patch( "controllers.console.tag.tags.TagService.get_tags", - return_value=[{"id": "1", "name": "tag"}], + return_value=[ + SimpleNamespace( + id="1", + name="tag", + type=TagType.KNOWLEDGE, + binding_count=1, + ) + ], ), ): result, status = method(api) assert status == 200 - assert isinstance(result, list) + assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}] def test_post_success(self, app, admin_user, tag, payload_patch): api = TagListApi() @@ -113,6 +122,7 @@ class TestTagListApi: assert status == 200 assert result["name"] == "test-tag" + assert result["binding_count"] == "0" def test_post_forbidden(self, app, readonly_user, payload_patch): api = TagListApi() @@ -158,7 +168,7 @@ class TestTagUpdateDeleteApi: result, status = method(api, "tag-1") assert status == 200 - assert result["binding_count"] == 3 + assert result["binding_count"] == "3" def test_patch_forbidden(self, app, readonly_user, payload_patch): api = TagUpdateDeleteApi() @@ -277,3 +287,13 @@ class TestTagBindingDeleteApi: ): with pytest.raises(Forbidden): method(api) + + +class TestTagResponseModel: + def test_tag_response_normalizes_enum_type(self): + payload = module.TagResponse.model_validate( + {"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1} + ).model_dump(mode="json") + + assert payload["type"] == "knowledge" + assert payload["binding_count"] == "1"