refactor(api): migrate console tag responses from marshal_with to BaseModel (#35208)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-15 02:57:27 -07:00 committed by GitHub
parent 425457cb16
commit dbceb3067e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 20 deletions

View File

@ -1,13 +1,14 @@
from typing import Literal
from flask import request
from flask_restx import Namespace, Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.enums import TagType
from services.tag_service import (
@ -18,17 +19,6 @@ from services.tag_service import (
UpdateTagPayload,
)
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel):
keyword: str | None = Field(None, description="Search keyword")
class TagResponse(ResponseModel):
id: str
name: str
type: str | None = None
binding_count: str | None = None
@field_validator("type", mode="before")
@classmethod
def normalize_type(cls, value: TagType | str | None) -> str | None:
if value is None:
return None
if isinstance(value, TagType):
return value.value
return value
@field_validator("binding_count", mode="before")
@classmethod
def normalize_binding_count(cls, value: int | str | None) -> str | None:
if value is None:
return None
return str(value)
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
TagResponse,
)
@ -69,14 +83,18 @@ class TagListApi(Resource):
@console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@marshal_with(dataset_tag_fields)
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
def get(self):
_, current_tenant_id = current_account_with_tenant()
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
return tags, 200
serialized_tags = [
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
]
return serialized_tags, 200
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@ -91,7 +109,9 @@ class TagListApi(Resource):
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200
@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200

View File

@ -1,9 +1,11 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
@ -83,13 +85,20 @@ class TestTagListApi:
),
patch(
"controllers.console.tag.tags.TagService.get_tags",
return_value=[{"id": "1", "name": "tag"}],
return_value=[
SimpleNamespace(
id="1",
name="tag",
type=TagType.KNOWLEDGE,
binding_count=1,
)
],
),
):
result, status = method(api)
assert status == 200
assert isinstance(result, list)
assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
def test_post_success(self, app, admin_user, tag, payload_patch):
api = TagListApi()
@ -113,6 +122,7 @@ class TestTagListApi:
assert status == 200
assert result["name"] == "test-tag"
assert result["binding_count"] == "0"
def test_post_forbidden(self, app, readonly_user, payload_patch):
api = TagListApi()
@ -158,7 +168,7 @@ class TestTagUpdateDeleteApi:
result, status = method(api, "tag-1")
assert status == 200
assert result["binding_count"] == 3
assert result["binding_count"] == "3"
def test_patch_forbidden(self, app, readonly_user, payload_patch):
api = TagUpdateDeleteApi()
@ -277,3 +287,13 @@ class TestTagBindingDeleteApi:
):
with pytest.raises(Forbidden):
method(api)
class TestTagResponseModel:
def test_tag_response_normalizes_enum_type(self):
payload = module.TagResponse.model_validate(
{"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1}
).model_dump(mode="json")
assert payload["type"] == "knowledge"
assert payload["binding_count"] == "1"