mirror of
https://github.com/langgenius/dify.git
synced 2026-06-11 02:31:13 +08:00
fix(api): require all selected tags in list filters (#37272)
This commit is contained in:
parent
9ac71329a4
commit
e3cfc4d40f
@ -148,7 +148,7 @@ class AppService:
|
||||
escaped_name = escape_like_pattern(name)
|
||||
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
|
||||
if params.tag_ids and len(params.tag_ids) > 0:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids)
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True)
|
||||
if target_ids and len(target_ids) > 0:
|
||||
filters.append(App.id.in_(target_ids))
|
||||
else:
|
||||
|
||||
@ -295,6 +295,7 @@ class DatasetService:
|
||||
"knowledge",
|
||||
tenant_id,
|
||||
tag_ids,
|
||||
match_all=True,
|
||||
)
|
||||
else:
|
||||
target_ids = []
|
||||
|
||||
@ -229,7 +229,7 @@ class SnippetService:
|
||||
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
|
||||
|
||||
if tag_ids:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids)
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True)
|
||||
if target_ids:
|
||||
stmt = stmt.where(CustomizedSnippet.id.in_(target_ids))
|
||||
else:
|
||||
|
||||
@ -54,19 +54,43 @@ class TagService:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
|
||||
def get_target_ids_by_tag_ids(
|
||||
tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False
|
||||
):
|
||||
"""
|
||||
Return target IDs bound to tags for the given tenant and resource type.
|
||||
|
||||
By default this preserves the legacy "match any tag" behavior and returns one target ID per matching
|
||||
binding. When match_all is enabled, every requested tag must exist for the tenant/type and each returned
|
||||
target must be bound to all requested tags.
|
||||
"""
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
# Deduplicate repeated query params so match_all counts each requested tag once.
|
||||
requested_tag_ids = list(dict.fromkeys(tag_ids))
|
||||
tags = db.session.scalars(
|
||||
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
select(Tag.id).where(
|
||||
Tag.id.in_(requested_tag_ids),
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type,
|
||||
)
|
||||
).all()
|
||||
if not tags:
|
||||
return []
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
tag_ids = list(tags)
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
if match_all:
|
||||
if len(tag_ids) != len(requested_tag_ids):
|
||||
return []
|
||||
return db.session.scalars(
|
||||
select(TagBinding.target_id)
|
||||
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
|
||||
.group_by(TagBinding.target_id)
|
||||
.having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids))
|
||||
).all()
|
||||
tag_bindings = db.session.scalars(
|
||||
select(TagBinding.target_id).where(
|
||||
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
|
||||
|
||||
@ -406,7 +406,7 @@ class TestAppService:
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
|
||||
|
||||
# Verify tag service was called
|
||||
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])
|
||||
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True)
|
||||
|
||||
# Verify results
|
||||
assert paginated_apps is not None
|
||||
|
||||
@ -289,14 +289,24 @@ class TestDatasetServiceGetDatasets:
|
||||
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(
|
||||
db_session_with_containers, tenant.id, account.id, dataset_2.id
|
||||
)
|
||||
db_session_with_containers.add(
|
||||
TagBinding(
|
||||
tenant_id=tenant.id,
|
||||
tag_id=tag_2.id,
|
||||
target_id=dataset_1.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
tag_ids = [tag_1.id, tag_2.id]
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 2
|
||||
assert total == 2
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
assert datasets[0].id == dataset_1.id
|
||||
|
||||
def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session):
|
||||
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
|
||||
|
||||
@ -492,6 +492,57 @@ class TestTagService:
|
||||
assert len(result) == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_target_ids_by_tag_ids_match_all(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test target ID retrieval when every requested tag must be bound to the same target.
|
||||
|
||||
This test verifies:
|
||||
- Targets with only one requested tag are excluded
|
||||
- Targets with all requested tags are returned once
|
||||
- Missing requested tags make the filter unsatisfiable
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
tags = self._create_test_tags(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2
|
||||
)
|
||||
dataset_with_all_tags = self._create_test_dataset(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant.id
|
||||
)
|
||||
dataset_with_one_tag = self._create_test_dataset(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant.id
|
||||
)
|
||||
self._create_test_tag_bindings(
|
||||
db_session_with_containers,
|
||||
mock_external_service_dependencies,
|
||||
tags,
|
||||
dataset_with_all_tags.id,
|
||||
tenant.id,
|
||||
)
|
||||
self._create_test_tag_bindings(
|
||||
db_session_with_containers,
|
||||
mock_external_service_dependencies,
|
||||
tags[:1],
|
||||
dataset_with_one_tag.id,
|
||||
tenant.id,
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == [dataset_with_all_tags.id]
|
||||
|
||||
missing_tag_result = TagService.get_target_ids_by_tag_ids(
|
||||
"knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True
|
||||
)
|
||||
assert missing_tag_result == []
|
||||
|
||||
def test_get_target_ids_by_tag_ids_no_matching_tags(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -94,12 +94,14 @@ def test_validate_snippet_graph_forbidden_nodes_raises_with_node_details() -> No
|
||||
|
||||
|
||||
def test_get_snippets_returns_empty_when_tag_filter_has_no_targets(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", Mock(return_value=[]))
|
||||
get_target_ids = Mock(return_value=[])
|
||||
monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", get_target_ids)
|
||||
service = SnippetService.__new__(SnippetService)
|
||||
|
||||
result = service.get_snippets(tenant_id="tenant-1", tag_ids=["tag-1"])
|
||||
|
||||
assert result == ([], 0, False)
|
||||
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
|
||||
|
||||
|
||||
def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -114,9 +116,10 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
|
||||
)
|
||||
service = SnippetService.__new__(SnippetService)
|
||||
service._session_maker = _session_maker(session)
|
||||
get_target_ids = Mock(return_value=["snippet-1", "snippet-2", "snippet-3"])
|
||||
monkeypatch.setattr(
|
||||
"services.snippet_service.TagService.get_target_ids_by_tag_ids",
|
||||
Mock(return_value=["snippet-1", "snippet-2", "snippet-3"]),
|
||||
get_target_ids,
|
||||
)
|
||||
|
||||
result, total, has_more = service.get_snippets(
|
||||
@ -132,6 +135,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa
|
||||
assert result == snippets[:2]
|
||||
assert total == 3
|
||||
assert has_more is True
|
||||
get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True)
|
||||
session.scalar.assert_called_once()
|
||||
session.scalars.assert_called_once()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user