diff --git a/api/services/app_service.py b/api/services/app_service.py index 58727e658c..e9741b8e23 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -148,7 +148,7 @@ class AppService: escaped_name = escape_like_pattern(name) filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) if params.tag_ids and len(params.tag_ids) > 0: - target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids) + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids, match_all=True) if target_ids and len(target_ids) > 0: filters.append(App.id.in_(target_ids)) else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ddf35e35e6..af50a0e318 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -295,6 +295,7 @@ class DatasetService: "knowledge", tenant_id, tag_ids, + match_all=True, ) else: target_ids = [] diff --git a/api/services/snippet_service.py b/api/services/snippet_service.py index aa46626e56..9f16d41204 100644 --- a/api/services/snippet_service.py +++ b/api/services/snippet_service.py @@ -229,7 +229,7 @@ class SnippetService: stmt = stmt.where(CustomizedSnippet.created_by.in_(creators)) if tag_ids: - target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids) + target_ids = TagService.get_target_ids_by_tag_ids("snippet", tenant_id, tag_ids, match_all=True) if target_ids: stmt = stmt.where(CustomizedSnippet.id.in_(target_ids)) else: diff --git a/api/services/tag_service.py b/api/services/tag_service.py index f1e5c3fc56..404ccb0d75 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -54,19 +54,43 @@ class TagService: return results @staticmethod - def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list): + def get_target_ids_by_tag_ids( + tag_type: str, current_tenant_id: str, tag_ids: list[str], *, match_all: bool = False + ): + """ + Return target IDs bound to tags for the given tenant and resource type. + + By default this preserves the legacy "match any tag" behavior and returns one target ID per matching + binding. When match_all is enabled, every requested tag must exist for the tenant/type and each returned + target must be bound to all requested tags. + """ # Check if tag_ids is not empty to avoid WHERE false condition if not tag_ids or len(tag_ids) == 0: return [] + # Deduplicate repeated query params so match_all counts each requested tag once. + requested_tag_ids = list(dict.fromkeys(tag_ids)) tags = db.session.scalars( - select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + select(Tag.id).where( + Tag.id.in_(requested_tag_ids), + Tag.tenant_id == current_tenant_id, + Tag.type == tag_type, + ) ).all() if not tags: return [] - tag_ids = [tag.id for tag in tags] + tag_ids = list(tags) # Check if tag_ids is not empty to avoid WHERE false condition if not tag_ids or len(tag_ids) == 0: return [] + if match_all: + if len(tag_ids) != len(requested_tag_ids): + return [] + return db.session.scalars( + select(TagBinding.target_id) + .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .group_by(TagBinding.target_id) + .having(func.count(sa.distinct(TagBinding.tag_id)) == len(tag_ids)) + ).all() tag_bindings = db.session.scalars( select(TagBinding.target_id).where( TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 03ea1d1fa0..56d643e4c1 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -406,7 +406,7 @@ class TestAppService: paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params) # Verify tag service was called - mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"]) + mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"], match_all=True) # Verify results assert paginated_apps is not None diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 0c610311bb..4e2bf9fc10 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -289,14 +289,24 @@ class TestDatasetServiceGetDatasets: tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding( db_session_with_containers, tenant.id, account.id, dataset_2.id ) + db_session_with_containers.add( + TagBinding( + tenant_id=tenant.id, + tag_id=tag_2.id, + target_id=dataset_1.id, + created_by=account.id, + ) + ) + db_session_with_containers.commit() tag_ids = [tag_1.id, tag_2.id] # Act datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids) # Assert - assert len(datasets) == 2 - assert total == 2 + assert len(datasets) == 1 + assert total == 1 + assert datasets[0].id == dataset_1.id def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session): """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index f088cc964d..f4854d1072 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -492,6 +492,57 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) + def test_get_target_ids_by_tag_ids_match_all( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test target ID retrieval when every requested tag must be bound to the same target. + + This test verifies: + - Targets with only one requested tag are excluded + - Targets with all requested tags are returned once + - Missing requested tags make the filter unsatisfiable + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2 + ) + dataset_with_all_tags = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, tenant.id + ) + dataset_with_one_tag = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, tenant.id + ) + self._create_test_tag_bindings( + db_session_with_containers, + mock_external_service_dependencies, + tags, + dataset_with_all_tags.id, + tenant.id, + ) + self._create_test_tag_bindings( + db_session_with_containers, + mock_external_service_dependencies, + tags[:1], + dataset_with_one_tag.id, + tenant.id, + ) + + # Act: Execute the method under test + tag_ids = [tag.id for tag in tags] + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids, match_all=True) + + # Assert: Verify the expected outcomes + assert result == [dataset_with_all_tags.id] + + missing_tag_result = TagService.get_target_ids_by_tag_ids( + "knowledge", tenant.id, [tags[0].id, str(uuid.uuid4())], match_all=True + ) + assert missing_tag_result == [] + def test_get_target_ids_by_tag_ids_no_matching_tags( self, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/unit_tests/services/test_snippet_service.py b/api/tests/unit_tests/services/test_snippet_service.py index 31146790d3..7cbe773e41 100644 --- a/api/tests/unit_tests/services/test_snippet_service.py +++ b/api/tests/unit_tests/services/test_snippet_service.py @@ -94,12 +94,14 @@ def test_validate_snippet_graph_forbidden_nodes_raises_with_node_details() -> No def test_get_snippets_returns_empty_when_tag_filter_has_no_targets(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", Mock(return_value=[])) + get_target_ids = Mock(return_value=[]) + monkeypatch.setattr("services.snippet_service.TagService.get_target_ids_by_tag_ids", get_target_ids) service = SnippetService.__new__(SnippetService) result = service.get_snippets(tenant_id="tenant-1", tag_ids=["tag-1"]) assert result == ([], 0, False) + get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True) def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPatch) -> None: @@ -114,9 +116,10 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa ) service = SnippetService.__new__(SnippetService) service._session_maker = _session_maker(session) + get_target_ids = Mock(return_value=["snippet-1", "snippet-2", "snippet-3"]) monkeypatch.setattr( "services.snippet_service.TagService.get_target_ids_by_tag_ids", - Mock(return_value=["snippet-1", "snippet-2", "snippet-3"]), + get_target_ids, ) result, total, has_more = service.get_snippets( @@ -132,6 +135,7 @@ def test_get_snippets_applies_filters_and_paginates(monkeypatch: pytest.MonkeyPa assert result == snippets[:2] assert total == 3 assert has_more is True + get_target_ids.assert_called_once_with("snippet", "tenant-1", ["tag-1"], match_all=True) session.scalar.assert_called_once() session.scalars.assert_called_once()