fix(api): require all selected tags in list filters (#37272)

This commit is contained in:
非法操作 2026-06-10 16:20:13 +08:00 committed by GitHub
parent 9ac71329a4
commit e3cfc4d40f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 100 additions and 10 deletions

View File

@ -148,7 +148,7 @@ class AppService:
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
if params.tag_ids and len(params.tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids)
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True)
if target_ids and len(target_ids) > 0:
filters.append(App.id.in_(target_ids))
else:

View File

@ -295,6 +295,7 @@ class DatasetService:
"knowledge",
tenant_id,
tag_ids,
match_all=True,
)
else:
target_ids = []

View File

@ -229,7 +229,7 @@ class SnippetService:
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids)
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True)
if target_ids:
stmt = stmt.where(CustomizedSnippet.id.in_(target_ids))
else:

View File

@ -54,19 +54,43 @@ class TagService:
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
def get_target_ids_by_tag_ids(
tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False
):
"""
Return target IDs bound to tags for the given tenant and resource type.
By default this preserves the legacy "match any tag" behavior and returns one target ID per matching
binding. When match_all is enabled, every requested tag must exist for the tenant/type and each returned
target must be bound to all requested tags.
"""
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
# Deduplicate repeated query params so match_all counts each requested tag once.
requested_tag_ids = list(dict.fromkeys(tag_ids))
tags = db.session.scalars(
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
select(Tag.id).where(
Tag.id.in_(requested_tag_ids),
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type,
)
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
tag_ids = list(tags)
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
if match_all:
if len(tag_ids) != len(requested_tag_ids):
return []
return db.session.scalars(
select(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.group_by(TagBinding.target_id)
.having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids))
).all()
tag_bindings = db.session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id

View File

@ -406,7 +406,7 @@ class TestAppService:
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
# Verify tag service was called
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True)
# Verify results
assert paginated_apps is not None

View File

@ -289,14 +289,24 @@ class TestDatasetServiceGetDatasets:
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(
db_session_with_containers, tenant.id, account.id, dataset_2.id
)
db_session_with_containers.add(
TagBinding(
tenant_id=tenant.id,
tag_id=tag_2.id,
target_id=dataset_1.id,
created_by=account.id,
)
)
db_session_with_containers.commit()
tag_ids = [tag_1.id, tag_2.id]
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
# Assert
assert len(datasets) == 2
assert total == 2
assert len(datasets) == 1
assert total == 1
assert datasets[0].id == dataset_1.id
def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session):
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""

View File

@ -492,6 +492,57 @@ class TestTagService:
assert len(result) == 0
assert isinstance(result, list)
def test_get_target_ids_by_tag_ids_match_all(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test target ID retrieval when every requested tag must be bound to the same target.
This test verifies:
- Targets with only one requested tag are excluded
- Targets with all requested tags are returned once
- Missing requested tags make the filter unsatisfiable
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
tags = self._create_test_tags(
db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2
)
dataset_with_all_tags = self._create_test_dataset(
db_session_with_containers, mock_external_service_dependencies, tenant.id
)
dataset_with_one_tag = self._create_test_dataset(
db_session_with_containers, mock_external_service_dependencies, tenant.id
)
self._create_test_tag_bindings(
db_session_with_containers,
mock_external_service_dependencies,
tags,
dataset_with_all_tags.id,
tenant.id,
)
self._create_test_tag_bindings(
db_session_with_containers,
mock_external_service_dependencies,
tags[:1],
dataset_with_one_tag.id,
tenant.id,
)
# Act: Execute the method under test
tag_ids = [tag.id for tag in tags]
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True)
# Assert: Verify the expected outcomes
assert result == [dataset_with_all_tags.id]
missing_tag_result = TagService.get_target_ids_by_tag_ids(
"knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True
)
assert missing_tag_result == []
def test_get_target_ids_by_tag_ids_no_matching_tags(
self, db_session_with_containers: Session, mock_external_service_dependencies
):

View File

@ -94,12 +94,14 @@ def test_validate_snippet_graph_forbidden_nodes_raises_with_node_details() -> No
def test_get_snippets_returns_empty_when_tag_filter_has_no_targets(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", Mock(return_value=[]))
get_target_ids = Mock(return_value=[])
monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", get_target_ids)
service = SnippetService.__new__(SnippetService)
result = service.get_snippets(tenant_id="tenant-1", tag_ids=["tag-1"])
assert result == ([], 0, False)
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPatch) -> None:
@ -114,9 +116,10 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
)
service = SnippetService.__new__(SnippetService)
service._session_maker = _session_maker(session)
get_target_ids = Mock(return_value=["snippet-1", "snippet-2", "snippet-3"])
monkeypatch.setattr(
"services.snippet_service.TagService.get_target_ids_by_tag_ids",
Mock(return_value=["snippet-1", "snippet-2", "snippet-3"]),
get_target_ids,
)
result, total, has_more = service.get_snippets(
@ -132,6 +135,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
assert result == snippets[:2]
assert total == 3
assert has_more is True
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
session.scalar.assert_called_once()
session.scalars.assert_called_once()