refactor(api): migrate console tag responses from marshal_with to BaseModel (#35208)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-15 02:57:27 -07:00 committed by GitHub
parent 425457cb16
commit dbceb3067e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 20 deletions

View File

@ -1,13 +1,14 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_restx import Namespace, Resource, fields, marshal_with from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.enums import TagType from models.enums import TagType
from services.tag_service import ( from services.tag_service import (
@ -18,17 +19,6 @@ from services.tag_service import (
UpdateTagPayload, UpdateTagPayload,
) )
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel): class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50) name: str = Field(description="Tag name", min_length=1, max_length=50)
@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel):
keyword: str | None = Field(None, description="Search keyword") keyword: str | None = Field(None, description="Search keyword")
class TagResponse(ResponseModel):
id: str
name: str
type: str | None = None
binding_count: str | None = None
@field_validator("type", mode="before")
@classmethod
def normalize_type(cls, value: TagType | str | None) -> str | None:
if value is None:
return None
if isinstance(value, TagType):
return value.value
return value
@field_validator("binding_count", mode="before")
@classmethod
def normalize_binding_count(cls, value: int | str | None) -> str | None:
if value is None:
return None
return str(value)
register_schema_models( register_schema_models(
console_ns, console_ns,
TagBasePayload, TagBasePayload,
TagBindingPayload, TagBindingPayload,
TagBindingRemovePayload, TagBindingRemovePayload,
TagListQueryParam, TagListQueryParam,
TagResponse,
) )
@ -69,14 +83,18 @@ class TagListApi(Resource):
@console_ns.doc( @console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
) )
@marshal_with(dataset_tag_fields) @console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
raw_args = request.args.to_dict() raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args) param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
return tags, 200 serialized_tags = [
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
]
return serialized_tags, 200
@console_ns.expect(console_ns.models[TagBasePayload.__name__]) @console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required @setup_required
@ -91,7 +109,9 @@ class TagListApi(Resource):
payload = TagBasePayload.model_validate(console_ns.payload or {}) payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type)) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200 return response, 200
@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource):
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200 return response, 200

View File

@ -1,9 +1,11 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import MagicMock, PropertyMock, patch
import pytest import pytest
from flask import Flask from flask import Flask
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.tag.tags import ( from controllers.console.tag.tags import (
TagBindingCreateApi, TagBindingCreateApi,
@ -83,13 +85,20 @@ class TestTagListApi:
), ),
patch( patch(
"controllers.console.tag.tags.TagService.get_tags", "controllers.console.tag.tags.TagService.get_tags",
return_value=[{"id": "1", "name": "tag"}], return_value=[
SimpleNamespace(
id="1",
name="tag",
type=TagType.KNOWLEDGE,
binding_count=1,
)
],
), ),
): ):
result, status = method(api) result, status = method(api)
assert status == 200 assert status == 200
assert isinstance(result, list) assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
def test_post_success(self, app, admin_user, tag, payload_patch): def test_post_success(self, app, admin_user, tag, payload_patch):
api = TagListApi() api = TagListApi()
@ -113,6 +122,7 @@ class TestTagListApi:
assert status == 200 assert status == 200
assert result["name"] == "test-tag" assert result["name"] == "test-tag"
assert result["binding_count"] == "0"
def test_post_forbidden(self, app, readonly_user, payload_patch): def test_post_forbidden(self, app, readonly_user, payload_patch):
api = TagListApi() api = TagListApi()
@ -158,7 +168,7 @@ class TestTagUpdateDeleteApi:
result, status = method(api, "tag-1") result, status = method(api, "tag-1")
assert status == 200 assert status == 200
assert result["binding_count"] == 3 assert result["binding_count"] == "3"
def test_patch_forbidden(self, app, readonly_user, payload_patch): def test_patch_forbidden(self, app, readonly_user, payload_patch):
api = TagUpdateDeleteApi() api = TagUpdateDeleteApi()
@ -277,3 +287,13 @@ class TestTagBindingDeleteApi:
): ):
with pytest.raises(Forbidden): with pytest.raises(Forbidden):
method(api) method(api)
class TestTagResponseModel:
def test_tag_response_normalizes_enum_type(self):
payload = module.TagResponse.model_validate(
{"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1}
).model_dump(mode="json")
assert payload["type"] == "knowledge"
assert payload["binding_count"] == "1"