refactor: TagService to accept db.session explicitly (#37416)

This commit is contained in:
cn7shi 2026-06-15 10:04:28 +08:00 committed by GitHub
parent c6b3e525d1
commit e0773c4d8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 31 additions and 14 deletions

View File

@ -16,6 +16,7 @@ from controllers.console.wraps import (
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import login_required
from models import Account
@ -101,7 +102,7 @@ class TagListApi(Resource):
def get(self, current_tenant_id: str):
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
tags = TagService.get_tags(db.session(), param.type, current_tenant_id, param.keyword)
serialized_tags = [
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags

View File

@ -22,6 +22,7 @@ from controllers.service_api.wraps import (
)
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from graphon.model_runtime.entities.model_entities import ModelType
@ -608,7 +609,7 @@ class DatasetTagsApi(DatasetApiResource):
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
tags = TagService.get_tags(db.session(), "knowledge", cid)
return dump_response(KnowledgeTagListResponse, tags), 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])

View File

@ -6,6 +6,7 @@ from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import delete, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
@ -38,7 +39,7 @@ class TagBindingDeletePayload(BaseModel):
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
def get_tags(session: Session, tag_type: str, current_tenant_id: str, keyword: str | None = None):
stmt = (
select(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
@ -50,7 +51,7 @@ class TagService:
escaped_keyword = escape_like_pattern(keyword)
stmt = stmt.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
stmt = stmt.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = list(db.session.execute(stmt.order_by(Tag.created_at.desc())).all())
results: list = list(session.execute(stmt.order_by(Tag.created_at.desc())).all())
return results
@staticmethod

View File

@ -23,6 +23,12 @@ from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
class SessionMatcher:
def __eq__(self, other):
return isinstance(other, Session)
import services
from controllers.service_api.dataset.dataset import (
DatasetCreatePayload,
@ -998,7 +1004,7 @@ class TestDatasetTagsApiGet:
assert status == 200
assert response == [{"id": "tag-1", "name": "Test Tag", "type": "knowledge", "binding_count": "0"}]
mock_tag_svc.get_tags.assert_called_once_with("knowledge", "tenant-1")
mock_tag_svc.get_tags.assert_called_once_with(SessionMatcher(), "knowledge", "tenant-1")
@patch("controllers.service_api.dataset.dataset.current_user")
def test_list_tags_from_db(

View File

@ -246,7 +246,7 @@ class TestTagService:
)
# Act: Execute the method under test
result = TagService.get_tags("knowledge", tenant.id)
result = TagService.get_tags(db_session_with_containers, "knowledge", tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -299,7 +299,7 @@ class TestTagService:
db_session_with_containers.commit()
# Act: Execute the method under test with keyword filter
result = TagService.get_tags("app", tenant.id, keyword="development")
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="development")
# Assert: Verify the expected outcomes
assert result is not None
@ -310,7 +310,7 @@ class TestTagService:
assert "development" in tag_result.name.lower()
# Verify no results for non-matching keyword
result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
result_no_match = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="nonexistent")
assert len(result_no_match) == 0
def test_get_tags_with_special_characters_in_keyword(
@ -371,22 +371,22 @@ class TestTagService:
db_session_with_containers.commit()
# Act & Assert: Test 1 - Search with % character
result = TagService.get_tags("app", tenant.id, keyword="50%")
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="50%")
assert len(result) == 1
assert result[0].name == "50% discount"
# Test 2 - Search with _ character
result = TagService.get_tags("app", tenant.id, keyword="test_data")
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="test_data")
assert len(result) == 1
assert result[0].name == "test_data_tag"
# Test 3 - Search with \ character
result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="path\\to\\tag")
assert len(result) == 1
assert result[0].name == "path\\to\\tag"
# Test 4 - Search with % should NOT match 100% (verifies escaping works)
result = TagService.get_tags("app", tenant.id, keyword="50%")
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="50%")
assert len(result) == 1
assert all("50%" in item.name for item in result)
@ -405,7 +405,7 @@ class TestTagService:
)
# Act: Execute the method under test
result = TagService.get_tags("knowledge", tenant.id)
result = TagService.get_tags(db_session_with_containers, "knowledge", tenant.id)
# Assert: Verify the expected outcomes
assert result is not None

View File

@ -2,6 +2,14 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from sqlalchemy.orm import Session
class SessionMatcher:
def __eq__(self, other):
return isinstance(other, Session)
from flask import Flask
from werkzeug.exceptions import Forbidden
@ -125,7 +133,7 @@ class TestTagListApi:
):
result, status = method(api, "tenant-1")
get_tags_mock.assert_called_once_with("snippet", "tenant-1", None)
get_tags_mock.assert_called_once_with(SessionMatcher(), "snippet", "tenant-1", None)
assert status == 200
assert result == [{"id": "1", "name": "snippet-tag", "type": "snippet", "binding_count": "1"}]