diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index a9a1157a29..82a713f1c6 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -16,6 +16,7 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from extensions.ext_database import db from fields.base import ResponseModel from libs.login import login_required from models import Account @@ -101,7 +102,7 @@ class TagListApi(Resource): def get(self, current_tenant_id: str): raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) - tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) + tags = TagService.get_tags(db.session(), param.type, current_tenant_id, param.keyword) serialized_tags = [ TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 89b8a0816f..c307063b3e 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -22,6 +22,7 @@ from controllers.service_api.wraps import ( ) from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.index_processor.constant.index_type import IndexTechniqueType +from extensions.ext_database import db from fields.base import ResponseModel from fields.dataset_fields import DatasetDetailResponse from graphon.model_runtime.entities.model_entities import ModelType @@ -608,7 +609,7 @@ class DatasetTagsApi(DatasetApiResource): assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None - tags = TagService.get_tags("knowledge", cid) + tags = TagService.get_tags(db.session(), "knowledge", cid) return dump_response(KnowledgeTagListResponse, tags), 200 @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 404ccb0d75..20f9a2c73d 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -6,6 +6,7 @@ from flask_login import current_user from pydantic import BaseModel, Field from sqlalchemy import delete, func, select from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from extensions.ext_database import db @@ -38,7 +39,7 @@ class TagBindingDeletePayload(BaseModel): class TagService: @staticmethod - def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None): + def get_tags(session: Session, tag_type: str, current_tenant_id: str, keyword: str | None = None): stmt = ( select(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) @@ -50,7 +51,7 @@ class TagService: escaped_keyword = escape_like_pattern(keyword) stmt = stmt.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) stmt = stmt.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) - results: list = list(db.session.execute(stmt.order_by(Tag.created_at.desc())).all()) + results: list = list(session.execute(stmt.order_by(Tag.created_at.desc())).all()) return results @staticmethod diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 91b0055e06..642dd3ab62 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -23,6 +23,12 @@ from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound + +class SessionMatcher: + def __eq__(self, other): + return isinstance(other, Session) + + import services from controllers.service_api.dataset.dataset import ( DatasetCreatePayload, @@ -998,7 +1004,7 @@ class TestDatasetTagsApiGet: assert status == 200 assert response == [{"id": "tag-1", "name": "Test Tag", "type": "knowledge", "binding_count": "0"}] - mock_tag_svc.get_tags.assert_called_once_with("knowledge", "tenant-1") + mock_tag_svc.get_tags.assert_called_once_with(SessionMatcher(), "knowledge", "tenant-1") @patch("controllers.service_api.dataset.dataset.current_user") def test_list_tags_from_db( diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index f4854d1072..517d5d2ed4 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -246,7 +246,7 @@ class TestTagService: ) # Act: Execute the method under test - result = TagService.get_tags("knowledge", tenant.id) + result = TagService.get_tags(db_session_with_containers, "knowledge", tenant.id) # Assert: Verify the expected outcomes assert result is not None @@ -299,7 +299,7 @@ class TestTagService: db_session_with_containers.commit() # Act: Execute the method under test with keyword filter - result = TagService.get_tags("app", tenant.id, keyword="development") + result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="development") # Assert: Verify the expected outcomes assert result is not None @@ -310,7 +310,7 @@ class TestTagService: assert "development" in tag_result.name.lower() # Verify no results for non-matching keyword - result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") + result_no_match = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="nonexistent") assert len(result_no_match) == 0 def test_get_tags_with_special_characters_in_keyword( @@ -371,22 +371,22 @@ class TestTagService: db_session_with_containers.commit() # Act & Assert: Test 1 - Search with % character - result = TagService.get_tags("app", tenant.id, keyword="50%") + result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="50%") assert len(result) == 1 assert result[0].name == "50% discount" # Test 2 - Search with _ character - result = TagService.get_tags("app", tenant.id, keyword="test_data") + result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="test_data") assert len(result) == 1 assert result[0].name == "test_data_tag" # Test 3 - Search with \ character - result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag") + result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="path\\to\\tag") assert len(result) == 1 assert result[0].name == "path\\to\\tag" # Test 4 - Search with % should NOT match 100% (verifies escaping works) - result = TagService.get_tags("app", tenant.id, keyword="50%") + result = TagService.get_tags(db_session_with_containers, "app", tenant.id, keyword="50%") assert len(result) == 1 assert all("50%" in item.name for item in result) @@ -405,7 +405,7 @@ class TestTagService: ) # Act: Execute the method under test - result = TagService.get_tags("knowledge", tenant.id) + result = TagService.get_tags(db_session_with_containers, "knowledge", tenant.id) # Assert: Verify the expected outcomes assert result is not None diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 3630f1bfec..dc3dd00a6c 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -2,6 +2,14 @@ from types import SimpleNamespace from unittest.mock import MagicMock, PropertyMock, patch import pytest +from sqlalchemy.orm import Session + + +class SessionMatcher: + def __eq__(self, other): + return isinstance(other, Session) + + from flask import Flask from werkzeug.exceptions import Forbidden @@ -125,7 +133,7 @@ class TestTagListApi: ): result, status = method(api, "tenant-1") - get_tags_mock.assert_called_once_with("snippet", "tenant-1", None) + get_tags_mock.assert_called_once_with(SessionMatcher(), "snippet", "tenant-1", None) assert status == 200 assert result == [{"id": "1", "name": "snippet-tag", "type": "snippet", "binding_count": "1"}]