refactor: replace dict params with BaseModel payloads in TagService (#34422)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
YBoy 2026-04-07 06:20:02 +02:00 committed by GitHub
parent b9c122e7f4
commit f67a811f7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 123 additions and 61 deletions

View File

@ -9,7 +9,14 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.tag_service import TagService from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
dataset_tag_fields = { dataset_tag_fields = {
"id": fields.String, "id": fields.String,
@ -25,19 +32,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
class TagBasePayload(BaseModel): class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50) name: str = Field(description="Tag name", min_length=1, max_length=50)
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") type: TagType = Field(description="Tag type")
class TagBindingPayload(BaseModel): class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind") tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to") target_id: str = Field(description="Target ID to bind tags to")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") type: TagType = Field(description="Tag type")
class TagBindingRemovePayload(BaseModel): class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove") tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from") target_id: str = Field(description="Target ID to unbind tag from")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel): class TagListQueryParam(BaseModel):
@ -82,7 +89,7 @@ class TagListApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {}) payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump()) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -103,7 +110,7 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {}) payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id) tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
@ -136,7 +143,9 @@ class TagBindingCreateApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {}) payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump()) TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -154,6 +163,8 @@ class TagBindingDeleteApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump()) TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -22,10 +22,17 @@ from fields.tag_fields import DataSetTag
from libs.login import current_user from libs.login import current_user
from models.account import Account from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -513,7 +520,7 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
response = DataSetTag.model_validate( response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -536,9 +543,8 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id tag_id = payload.tag_id
tag = TagService.update_tags(params, tag_id) tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
@ -585,7 +591,9 @@ class DatasetTagBindingApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"}) TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204 return "", 204
@ -609,7 +617,9 @@ class DatasetTagUnbindingApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"}) TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204 return "", 204

View File

@ -2,6 +2,7 @@ import uuid
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import current_user from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import func, select from sqlalchemy import func, select
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -11,6 +12,28 @@ from models.enums import TagType
from models.model import App, Tag, TagBinding from models.model import App, Tag, TagBinding
class SaveTagPayload(BaseModel):
name: str = Field(min_length=1, max_length=50)
type: TagType
class UpdateTagPayload(BaseModel):
name: str = Field(min_length=1, max_length=50)
type: TagType
class TagBindingCreatePayload(BaseModel):
tag_ids: list[str]
target_id: str
type: TagType
class TagBindingDeletePayload(BaseModel):
tag_id: str
target_id: str
type: TagType
class TagService: class TagService:
@staticmethod @staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None): def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
@ -78,12 +101,12 @@ class TagService:
return tags or [] return tags or []
@staticmethod @staticmethod
def save_tags(args: dict) -> Tag: def save_tags(payload: SaveTagPayload) -> Tag:
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]): if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name):
raise ValueError("Tag name already exists") raise ValueError("Tag name already exists")
tag = Tag( tag = Tag(
name=args["name"], name=payload.name,
type=TagType(args["type"]), type=TagType(payload.type),
created_by=current_user.id, created_by=current_user.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
) )
@ -93,13 +116,24 @@ class TagService:
return tag return tag
@staticmethod @staticmethod
def update_tags(args: dict, tag_id: str) -> Tag: def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag: if not tag:
raise NotFound("Tag not found") raise NotFound("Tag not found")
tag.name = args["name"] if payload.name != tag.name:
existing = db.session.scalar(
select(Tag)
.where(
Tag.name == payload.name,
Tag.tenant_id == current_user.current_tenant_id,
Tag.type == tag.type,
Tag.id != tag_id,
)
.limit(1)
)
if existing:
raise ValueError("Tag name already exists")
tag.name = payload.name
db.session.commit() db.session.commit()
return tag return tag
@ -122,21 +156,19 @@ class TagService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def save_tag_binding(args): def save_tag_binding(payload: TagBindingCreatePayload):
# check if target exists TagService.check_target_exists(payload.type, payload.target_id)
TagService.check_target_exists(args["type"], args["target_id"]) for tag_id in payload.tag_ids:
# save tag binding
for tag_id in args["tag_ids"]:
tag_binding = db.session.scalar( tag_binding = db.session.scalar(
select(TagBinding) select(TagBinding)
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) .where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id)
.limit(1) .limit(1)
) )
if tag_binding: if tag_binding:
continue continue
new_tag_binding = TagBinding( new_tag_binding = TagBinding(
tag_id=tag_id, tag_id=tag_id,
target_id=args["target_id"], target_id=payload.target_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
created_by=current_user.id, created_by=current_user.id,
) )
@ -144,17 +176,15 @@ class TagService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def delete_tag_binding(args): def delete_tag_binding(payload: TagBindingDeletePayload):
# check if target exists TagService.check_target_exists(payload.type, payload.target_id)
TagService.check_target_exists(args["type"], args["target_id"]) tag_binding = db.session.scalar(
# delete tag binding
tag_bindings = db.session.scalar(
select(TagBinding) select(TagBinding)
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"]) .where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id)
.limit(1) .limit(1)
) )
if tag_bindings: if tag_binding:
db.session.delete(tag_bindings) db.session.delete(tag_binding)
db.session.commit() db.session.commit()
@staticmethod @staticmethod

View File

@ -970,8 +970,10 @@ class TestDatasetTagBindingApiPost:
result = api.post(_=None) result = api.post(_=None)
assert result == ("", 204) assert result == ("", 204)
from services.tag_service import TagBindingCreatePayload
mock_tag_svc.save_tag_binding.assert_called_once_with( mock_tag_svc.save_tag_binding.assert_called_once_with(
{"tag_ids": ["tag-1"], "target_id": "ds-1", "type": "knowledge"} TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge")
) )
@patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.current_user")
@ -1019,8 +1021,10 @@ class TestDatasetTagUnbindingApiPost:
result = api.post(_=None) result = api.post(_=None)
assert result == ("", 204) assert result == ("", 204)
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with( mock_tag_svc.delete_tag_binding.assert_called_once_with(
{"tag_id": "tag-1", "target_id": "ds-1", "type": "knowledge"} TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge")
) )
@patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.current_user")

View File

@ -12,7 +12,13 @@ from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset from models.dataset import Dataset
from models.enums import DataSourceType, TagType from models.enums import DataSourceType, TagType
from models.model import App, Tag, TagBinding from models.model import App, Tag, TagBinding
from services.tag_service import TagService from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
class TestTagService: class TestTagService:
@ -685,7 +691,7 @@ class TestTagService:
db_session_with_containers, mock_external_service_dependencies db_session_with_containers, mock_external_service_dependencies
) )
tag_args = {"name": "test_tag_name", "type": "knowledge"} tag_args = SaveTagPayload(name="test_tag_name", type="knowledge")
# Act: Execute the method under test # Act: Execute the method under test
result = TagService.save_tags(tag_args) result = TagService.save_tags(tag_args)
@ -725,7 +731,7 @@ class TestTagService:
) )
# Create first tag # Create first tag
tag_args = {"name": "duplicate_tag", "type": "app"} tag_args = SaveTagPayload(name="duplicate_tag", type="app")
TagService.save_tags(tag_args) TagService.save_tags(tag_args)
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
@ -749,11 +755,11 @@ class TestTagService:
) )
# Create a tag to update # Create a tag to update
tag_args = {"name": "original_name", "type": "knowledge"} tag_args = SaveTagPayload(name="original_name", type="knowledge")
tag = TagService.save_tags(tag_args) tag = TagService.save_tags(tag_args)
# Update args # Update args
update_args = {"name": "updated_name", "type": "knowledge"} update_args = UpdateTagPayload(name="updated_name", type="knowledge")
# Act: Execute the method under test # Act: Execute the method under test
result = TagService.update_tags(update_args, tag.id) result = TagService.update_tags(update_args, tag.id)
@ -793,7 +799,7 @@ class TestTagService:
non_existent_tag_id = str(uuid.uuid4()) non_existent_tag_id = str(uuid.uuid4())
update_args = {"name": "updated_name", "type": "knowledge"} update_args = UpdateTagPayload(name="updated_name", type="knowledge")
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info: with pytest.raises(NotFound) as exc_info:
@ -817,14 +823,14 @@ class TestTagService:
) )
# Create two tags # Create two tags
tag1_args = {"name": "first_tag", "type": "app"} tag1_args = SaveTagPayload(name="first_tag", type="app")
tag1 = TagService.save_tags(tag1_args) tag1 = TagService.save_tags(tag1_args)
tag2_args = {"name": "second_tag", "type": "app"} tag2_args = SaveTagPayload(name="second_tag", type="app")
tag2 = TagService.save_tags(tag2_args) tag2 = TagService.save_tags(tag2_args)
# Try to update second tag with first tag's name # Try to update second tag with first tag's name
update_args = {"name": "first_tag", "type": "app"} update_args = UpdateTagPayload(name="first_tag", type="app")
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
@ -988,8 +994,10 @@ class TestTagService:
dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Execute the method under test # Act: Execute the method under test
binding_args = {"type": "knowledge", "target_id": dataset.id, "tag_ids": [tag.id for tag in tags]} binding_payload = TagBindingCreatePayload(
TagService.save_tag_binding(binding_args) type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
)
TagService.save_tag_binding(binding_payload)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -1030,11 +1038,11 @@ class TestTagService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Create first binding # Create first binding
binding_args = {"type": "app", "target_id": app.id, "tag_ids": [tag.id]} binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id])
TagService.save_tag_binding(binding_args) TagService.save_tag_binding(binding_payload)
# Act: Try to create duplicate binding # Act: Try to create duplicate binding
TagService.save_tag_binding(binding_args) TagService.save_tag_binding(binding_payload)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
@ -1071,11 +1079,10 @@ class TestTagService:
non_existent_target_id = str(uuid.uuid4()) non_existent_target_id = str(uuid.uuid4())
# Act & Assert: Verify proper error handling # Act & Assert: Verify proper error handling
binding_args = {"type": "invalid_type", "target_id": non_existent_target_id, "tag_ids": [tag.id]} from pydantic import ValidationError
with pytest.raises(NotFound) as exc_info: with pytest.raises(ValidationError):
TagService.save_tag_binding(binding_args) TagBindingCreatePayload(type="invalid_type", target_id=non_existent_target_id, tag_ids=[tag.id])
assert "Invalid binding type" in str(exc_info.value)
def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
""" """
@ -1113,8 +1120,8 @@ class TestTagService:
assert binding_before is not None assert binding_before is not None
# Act: Execute the method under test # Act: Execute the method under test
delete_args = {"type": "knowledge", "target_id": dataset.id, "tag_id": tag.id} delete_payload = TagBindingDeletePayload(type="knowledge", target_id=dataset.id, tag_id=tag.id)
TagService.delete_tag_binding(delete_args) TagService.delete_tag_binding(delete_payload)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# Verify tag binding was deleted # Verify tag binding was deleted
@ -1149,8 +1156,8 @@ class TestTagService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Try to delete non-existent binding # Act: Try to delete non-existent binding
delete_args = {"type": "app", "target_id": app.id, "tag_id": tag.id} delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_id=tag.id)
TagService.delete_tag_binding(delete_args) TagService.delete_tag_binding(delete_payload)
# Assert: Verify the expected outcomes # Assert: Verify the expected outcomes
# No error should be raised, and database state should remain unchanged # No error should be raised, and database state should remain unchanged