mirror of
https://github.com/langgenius/dify.git
synced 2026-06-16 14:01:10 +08:00
refactor: TagService to accept db.session explicitly (#37416)
This commit is contained in:
parent
c6b3e525d1
commit
e0773c4d8f
@ -16,6 +16,7 @@ from controllers.console.wraps import (
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
@ -101,7 +102,7 @@ class TagListApi(Resource):
|
||||
def get(self, current_tenant_id: str):
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
tags = TagService.get_tags(db.session(), param.type, current_tenant_id, param.keyword)
|
||||
|
||||
serialized_tags = [
|
||||
TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags
|
||||
|
||||
@ -22,6 +22,7 @@ from controllers.service_api.wraps import (
|
||||
)
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
@ -608,7 +609,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
tags = TagService.get_tags(db.session(), "knowledge", cid)
|
||||
return dump_response(KnowledgeTagListResponse, tags), 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
|
||||
@ -6,6 +6,7 @@ from flask_login import current_user
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -38,7 +39,7 @@ class TagBindingDeletePayload(BaseModel):
|
||||
|
||||
class TagService:
|
||||
@staticmethod
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
|
||||
def get_tags(session: Session, tag_type: str, current_tenant_id: str, keyword: str | None = None):
|
||||
stmt = (
|
||||
select(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
@ -50,7 +51,7 @@ class TagService:
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
stmt = stmt.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
|
||||
stmt = stmt.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results: list = list(db.session.execute(stmt.order_by(Tag.created_at.desc())).all())
|
||||
results: list = list(session.execute(stmt.order_by(Tag.created_at.desc())).all())
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -23,6 +23,12 @@ from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
||||
class SessionMatcher:
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Session)
|
||||
|
||||
|
||||
import services
|
||||
from controllers.service_api.dataset.dataset import (
|
||||
DatasetCreatePayload,
|
||||
@ -998,7 +1004,7 @@ class TestDatasetTagsApiGet:
|
||||
|
||||
assert status == 200
|
||||
assert response == [{"id": "tag-1", "name": "Test Tag", "type": "knowledge", "binding_count": "0"}]
|
||||
mock_tag_svc.get_tags.assert_called_once_with("knowledge", "tenant-1")
|
||||
mock_tag_svc.get_tags.assert_called_once_with(SessionMatcher(), "knowledge", "tenant-1")
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
def test_list_tags_from_db(
|
||||
|
||||
@ -246,7 +246,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tags("knowledge", tenant.id)
|
||||
result = TagService.get_tags(db_session_with_containers, "knowledge", tenant.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -299,7 +299,7 @@ class TestTagService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test with keyword filter
|
||||
result = TagService.get_tags("app", tenant.id, keyword="development")
|
||||
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="development")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -310,7 +310,7 @@ class TestTagService:
|
||||
assert "development" in tag_result.name.lower()
|
||||
|
||||
# Verify no results for non-matching keyword
|
||||
result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
|
||||
result_no_match = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="nonexistent")
|
||||
assert len(result_no_match) == 0
|
||||
|
||||
def test_get_tags_with_special_characters_in_keyword(
|
||||
@ -371,22 +371,22 @@ class TestTagService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Test 1 - Search with % character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="50%")
|
||||
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="50%")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "50% discount"
|
||||
|
||||
# Test 2 - Search with _ character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="test_data")
|
||||
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="test_data")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "test_data_tag"
|
||||
|
||||
# Test 3 - Search with \ character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
|
||||
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="path\\to\\tag")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "path\\to\\tag"
|
||||
|
||||
# Test 4 - Search with % should NOT match 100% (verifies escaping works)
|
||||
result = TagService.get_tags("app", tenant.id, keyword="50%")
|
||||
result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="50%")
|
||||
assert len(result) == 1
|
||||
assert all("50%" in item.name for item in result)
|
||||
|
||||
@ -405,7 +405,7 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tags("knowledge", tenant.id)
|
||||
result = TagService.get_tags(db_session_with_containers, "knowledge", tenant.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
|
||||
@ -2,6 +2,14 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class SessionMatcher:
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Session)
|
||||
|
||||
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -125,7 +133,7 @@ class TestTagListApi:
|
||||
):
|
||||
result, status = method(api, "tenant-1")
|
||||
|
||||
get_tags_mock.assert_called_once_with("snippet", "tenant-1", None)
|
||||
get_tags_mock.assert_called_once_with(SessionMatcher(), "snippet", "tenant-1", None)
|
||||
assert status == 200
|
||||
assert result == [{"id": "1", "name": "snippet-tag", "type": "snippet", "binding_count": "1"}]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user