mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 10:27:00 +08:00
refactor(api): migrate console tag responses from marshal_with to BaseModel (#35208)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
parent
425457cb16
commit
dbceb3067e
@ -1,13 +1,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
@ -18,17 +19,6 @@ from services.tag_service import (
|
||||
UpdateTagPayload,
|
||||
)
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel):
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
class TagResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str | None = None
|
||||
binding_count: str | None = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def normalize_type(cls, value: TagType | str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, TagType):
|
||||
return value.value
|
||||
return value
|
||||
|
||||
@field_validator("binding_count", mode="before")
|
||||
@classmethod
|
||||
def normalize_binding_count(cls, value: int | str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
TagResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -69,14 +83,18 @@ class TagListApi(Resource):
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@marshal_with(dataset_tag_fields)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
|
||||
return tags, 200
|
||||
serialized_tags = [
|
||||
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
|
||||
]
|
||||
|
||||
return serialized_tags, 200
|
||||
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@ -91,7 +109,9 @@ class TagListApi(Resource):
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
|
||||
return response, 200
|
||||
|
||||
@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import controllers.console.tag.tags as module
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.tag.tags import (
|
||||
TagBindingCreateApi,
|
||||
@ -83,13 +85,20 @@ class TestTagListApi:
|
||||
),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.get_tags",
|
||||
return_value=[{"id": "1", "name": "tag"}],
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
id="1",
|
||||
name="tag",
|
||||
type=TagType.KNOWLEDGE,
|
||||
binding_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(result, list)
|
||||
assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
|
||||
|
||||
def test_post_success(self, app, admin_user, tag, payload_patch):
|
||||
api = TagListApi()
|
||||
@ -113,6 +122,7 @@ class TestTagListApi:
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "test-tag"
|
||||
assert result["binding_count"] == "0"
|
||||
|
||||
def test_post_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagListApi()
|
||||
@ -158,7 +168,7 @@ class TestTagUpdateDeleteApi:
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
assert status == 200
|
||||
assert result["binding_count"] == 3
|
||||
assert result["binding_count"] == "3"
|
||||
|
||||
def test_patch_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagUpdateDeleteApi()
|
||||
@ -277,3 +287,13 @@ class TestTagBindingDeleteApi:
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestTagResponseModel:
|
||||
def test_tag_response_normalizes_enum_type(self):
|
||||
payload = module.TagResponse.model_validate(
|
||||
{"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1}
|
||||
).model_dump(mode="json")
|
||||
|
||||
assert payload["type"] == "knowledge"
|
||||
assert payload["binding_count"] == "1"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user