mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 17:47:30 +08:00
refactor(api): migrate console tag responses from marshal_with to BaseModel (#35208)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
parent
425457cb16
commit
dbceb3067e
@ -1,13 +1,14 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
|
from fields.base import ResponseModel
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.enums import TagType
|
from models.enums import TagType
|
||||||
from services.tag_service import (
|
from services.tag_service import (
|
||||||
@ -18,17 +19,6 @@ from services.tag_service import (
|
|||||||
UpdateTagPayload,
|
UpdateTagPayload,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_tag_fields = {
|
|
||||||
"id": fields.String,
|
|
||||||
"name": fields.String,
|
|
||||||
"type": fields.String,
|
|
||||||
"binding_count": fields.String,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
|
||||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
|
||||||
|
|
||||||
|
|
||||||
class TagBasePayload(BaseModel):
|
class TagBasePayload(BaseModel):
|
||||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||||
@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel):
|
|||||||
keyword: str | None = Field(None, description="Search keyword")
|
keyword: str | None = Field(None, description="Search keyword")
|
||||||
|
|
||||||
|
|
||||||
|
class TagResponse(ResponseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: str | None = None
|
||||||
|
binding_count: str | None = None
|
||||||
|
|
||||||
|
@field_validator("type", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_type(cls, value: TagType | str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, TagType):
|
||||||
|
return value.value
|
||||||
|
return value
|
||||||
|
|
||||||
|
@field_validator("binding_count", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_binding_count(cls, value: int | str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(
|
register_schema_models(
|
||||||
console_ns,
|
console_ns,
|
||||||
TagBasePayload,
|
TagBasePayload,
|
||||||
TagBindingPayload,
|
TagBindingPayload,
|
||||||
TagBindingRemovePayload,
|
TagBindingRemovePayload,
|
||||||
TagListQueryParam,
|
TagListQueryParam,
|
||||||
|
TagResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -69,14 +83,18 @@ class TagListApi(Resource):
|
|||||||
@console_ns.doc(
|
@console_ns.doc(
|
||||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||||
)
|
)
|
||||||
@marshal_with(dataset_tag_fields)
|
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||||
def get(self):
|
def get(self):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raw_args = request.args.to_dict()
|
raw_args = request.args.to_dict()
|
||||||
param = TagListQueryParam.model_validate(raw_args)
|
param = TagListQueryParam.model_validate(raw_args)
|
||||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||||
|
|
||||||
return tags, 200
|
serialized_tags = [
|
||||||
|
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
|
||||||
|
]
|
||||||
|
|
||||||
|
return serialized_tags, 200
|
||||||
|
|
||||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@ -91,7 +109,9 @@ class TagListApi(Resource):
|
|||||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||||
|
|
||||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
response = TagResponse.model_validate(
|
||||||
|
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||||
|
).model_dump(mode="json")
|
||||||
|
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource):
|
|||||||
|
|
||||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||||
|
|
||||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
response = TagResponse.model_validate(
|
||||||
|
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||||
|
).model_dump(mode="json")
|
||||||
|
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock, PropertyMock, patch
|
from unittest.mock import MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
import controllers.console.tag.tags as module
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.tag.tags import (
|
from controllers.console.tag.tags import (
|
||||||
TagBindingCreateApi,
|
TagBindingCreateApi,
|
||||||
@ -83,13 +85,20 @@ class TestTagListApi:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.tag.tags.TagService.get_tags",
|
"controllers.console.tag.tags.TagService.get_tags",
|
||||||
return_value=[{"id": "1", "name": "tag"}],
|
return_value=[
|
||||||
|
SimpleNamespace(
|
||||||
|
id="1",
|
||||||
|
name="tag",
|
||||||
|
type=TagType.KNOWLEDGE,
|
||||||
|
binding_count=1,
|
||||||
|
)
|
||||||
|
],
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result, status = method(api)
|
result, status = method(api)
|
||||||
|
|
||||||
assert status == 200
|
assert status == 200
|
||||||
assert isinstance(result, list)
|
assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
|
||||||
|
|
||||||
def test_post_success(self, app, admin_user, tag, payload_patch):
|
def test_post_success(self, app, admin_user, tag, payload_patch):
|
||||||
api = TagListApi()
|
api = TagListApi()
|
||||||
@ -113,6 +122,7 @@ class TestTagListApi:
|
|||||||
|
|
||||||
assert status == 200
|
assert status == 200
|
||||||
assert result["name"] == "test-tag"
|
assert result["name"] == "test-tag"
|
||||||
|
assert result["binding_count"] == "0"
|
||||||
|
|
||||||
def test_post_forbidden(self, app, readonly_user, payload_patch):
|
def test_post_forbidden(self, app, readonly_user, payload_patch):
|
||||||
api = TagListApi()
|
api = TagListApi()
|
||||||
@ -158,7 +168,7 @@ class TestTagUpdateDeleteApi:
|
|||||||
result, status = method(api, "tag-1")
|
result, status = method(api, "tag-1")
|
||||||
|
|
||||||
assert status == 200
|
assert status == 200
|
||||||
assert result["binding_count"] == 3
|
assert result["binding_count"] == "3"
|
||||||
|
|
||||||
def test_patch_forbidden(self, app, readonly_user, payload_patch):
|
def test_patch_forbidden(self, app, readonly_user, payload_patch):
|
||||||
api = TagUpdateDeleteApi()
|
api = TagUpdateDeleteApi()
|
||||||
@ -277,3 +287,13 @@ class TestTagBindingDeleteApi:
|
|||||||
):
|
):
|
||||||
with pytest.raises(Forbidden):
|
with pytest.raises(Forbidden):
|
||||||
method(api)
|
method(api)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTagResponseModel:
|
||||||
|
def test_tag_response_normalizes_enum_type(self):
|
||||||
|
payload = module.TagResponse.model_validate(
|
||||||
|
{"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1}
|
||||||
|
).model_dump(mode="json")
|
||||||
|
|
||||||
|
assert payload["type"] == "knowledge"
|
||||||
|
assert payload["binding_count"] == "1"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user