diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py new file mode 100644 index 0000000000..f605a286ed --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -0,0 +1,643 @@ +""" +Comprehensive integration tests for DatasetService retrieval/list methods. + +This test suite covers: +- get_datasets - pagination, search, filtering, permissions +- get_dataset - single dataset retrieval +- get_datasets_by_ids - bulk retrieval +- get_process_rules - dataset processing rules +- get_dataset_queries - dataset query history +- get_related_apps - apps using the dataset +""" + +import json +from uuid import uuid4 + +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetPermission, + DatasetPermissionEnum, + DatasetProcessRule, + DatasetQuery, +) +from models.model import Tag, TagBinding +from services.dataset_service import DatasetService, DocumentService + + +class DatasetRetrievalTestDataFactory: + """Factory class for creating database-backed test data for dataset retrieval integration tests.""" + + @staticmethod + def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]: + """Create an account and tenant with the specified role.""" + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant( + name=f"tenant-{uuid4()}", + status="normal", + ) + db.session.add_all([account, tenant]) + db.session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db.session.add(join) + db.session.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account: + """Create an account and add it to an existing tenant.""" + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db.session.add(join) + db.session.commit() + + account.current_tenant = tenant + return account + + @staticmethod + def create_dataset( + tenant_id: str, + created_by: str, + name: str = "Test Dataset", + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + ) -> Dataset: + """Create a dataset.""" + dataset = Dataset( + tenant_id=tenant_id, + name=name, + description="desc", + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=created_by, + permission=permission, + provider="vendor", + retrieval_model={"top_k": 2}, + ) + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission: + """Create a dataset permission.""" + permission = DatasetPermission( + dataset_id=dataset_id, + tenant_id=tenant_id, + account_id=account_id, + has_permission=True, + ) + db.session.add(permission) + db.session.commit() + return permission + + @staticmethod + def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule: + """Create a dataset process rule.""" + process_rule = DatasetProcessRule( + dataset_id=dataset_id, + created_by=created_by, + mode=mode, + rules=json.dumps(rules), + ) + db.session.add(process_rule) + db.session.commit() + return process_rule + + @staticmethod + def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery: + """Create a dataset query.""" + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=content, + source="web", + source_app_id=None, + created_by_role="account", + created_by=created_by, + ) + db.session.add(dataset_query) + db.session.commit() + return dataset_query + + @staticmethod + def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin: + """Create an app-dataset join.""" + join = AppDatasetJoin( + app_id=str(uuid4()), + dataset_id=dataset_id, + ) + db.session.add(join) + db.session.commit() + return join + + @staticmethod + def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag: + """Create a knowledge tag and bind it to the target dataset.""" + tag = Tag( + tenant_id=tenant_id, + type="knowledge", + name=f"tag-{uuid4()}", + created_by=created_by, + ) + db.session.add(tag) + db.session.flush() + + binding = TagBinding( + tenant_id=tenant_id, + tag_id=tag.id, + target_id=target_id, + created_by=created_by, + ) + db.session.add(binding) + db.session.commit() + return tag + + +class TestDatasetServiceGetDatasets: + """ + Comprehensive integration tests for DatasetService.get_datasets method. + + This test suite covers: + - Pagination + - Search functionality + - Tag filtering + - Permission-based filtering (ONLY_ME, ALL_TEAM, PARTIAL_TEAM) + - Role-based filtering (OWNER, DATASET_OPERATOR, NORMAL) + - include_all flag + """ + + # ==================== Basic Retrieval Tests ==================== + + def test_get_datasets_basic_pagination(self, db_session_with_containers): + """Test basic pagination without user or filters.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + page = 1 + per_page = 20 + + for i in range(5): + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name=f"Dataset {i}", + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id) + + # Assert + assert len(datasets) == 5 + assert total == 5 + + def test_get_datasets_with_search(self, db_session_with_containers): + """Test get_datasets with search keyword.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + page = 1 + per_page = 20 + search = "test" + + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name="Test Dataset", + permission=DatasetPermissionEnum.ALL_TEAM, + ) + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name="Another Dataset", + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, search=search) + + # Assert + assert len(datasets) == 1 + assert total == 1 + + def test_get_datasets_with_tag_filtering(self, db_session_with_containers): + """Test get_datasets with tag_ids filtering.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + page = 1 + per_page = 20 + + dataset_1 = DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + dataset_2 = DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id) + tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id) + tag_ids = [tag_1.id, tag_2.id] + + # Act + datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids) + + # Assert + assert len(datasets) == 2 + assert total == 2 + + def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers): + """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + page = 1 + per_page = 20 + tag_ids = [] + + for i in range(3): + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name=f"dataset-{i}", + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, tag_ids=tag_ids) + + # Assert + # When tag_ids is empty, tag filtering is skipped, so normal query results are returned + assert len(datasets) == 3 + assert total == 3 + + # ==================== Permission-Based Filtering Tests ==================== + + def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers): + """Test that without user, only ALL_TEAM datasets are shown.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + page = 1 + per_page = 20 + + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant.id, user=None) + + # Assert + assert len(datasets) == 1 + assert total == 1 + + def test_get_datasets_owner_with_include_all(self, db_session_with_containers): + """Test that OWNER with include_all=True sees all datasets.""" + # Arrange + owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + + for i, permission in enumerate( + [DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM] + ): + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=owner.id, + name=f"dataset-{i}", + permission=permission, + ) + + # Act + datasets, total = DatasetService.get_datasets( + page=1, + per_page=20, + tenant_id=tenant.id, + user=owner, + include_all=True, + ) + + # Assert + assert len(datasets) == 3 + assert total == 3 + + def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers): + """Test that normal user sees ONLY_ME datasets they created.""" + # Arrange + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) + + # Assert + assert len(datasets) == 1 + assert total == 1 + + def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers): + """Test that normal user sees ALL_TEAM datasets.""" + # Arrange + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) + + # Assert + assert len(datasets) == 1 + assert total == 1 + + def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers): + """Test that normal user sees PARTIAL_TEAM datasets they have permission for.""" + # Arrange + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + + dataset = DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id) + + # Act + datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) + + # Assert + assert len(datasets) == 1 + assert total == 1 + + def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers): + """Test that DATASET_OPERATOR only sees datasets they have explicit permission for.""" + # Arrange + operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + + dataset = DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id) + + # Act + datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) + + # Assert + assert len(datasets) == 1 + assert total == 1 + + def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers): + """Test that DATASET_OPERATOR without permissions returns empty result.""" + # Arrange + operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + DatasetRetrievalTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) + + # Assert + assert datasets == [] + assert total == 0 + + +class TestDatasetServiceGetDataset: + """Comprehensive integration tests for DatasetService.get_dataset method.""" + + def test_get_dataset_success(self, db_session_with_containers): + """Test successful retrieval of a single dataset.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + + # Act + result = DatasetService.get_dataset(dataset.id) + + # Assert + assert result is not None + assert result.id == dataset.id + + def test_get_dataset_not_found(self, db_session_with_containers): + """Test retrieval when dataset doesn't exist.""" + # Arrange + dataset_id = str(uuid4()) + + # Act + result = DatasetService.get_dataset(dataset_id) + + # Assert + assert result is None + + +class TestDatasetServiceGetDatasetsByIds: + """Comprehensive integration tests for DatasetService.get_datasets_by_ids method.""" + + def test_get_datasets_by_ids_success(self, db_session_with_containers): + """Test successful bulk retrieval of datasets by IDs.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + datasets = [ + DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3) + ] + dataset_ids = [dataset.id for dataset in datasets] + + # Act + result_datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant.id) + + # Assert + assert len(result_datasets) == 3 + assert total == 3 + assert all(dataset.id in dataset_ids for dataset in result_datasets) + + def test_get_datasets_by_ids_empty_list(self, db_session_with_containers): + """Test get_datasets_by_ids with empty list returns empty result.""" + # Arrange + tenant_id = str(uuid4()) + dataset_ids = [] + + # Act + datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id) + + # Assert + assert datasets == [] + assert total == 0 + + def test_get_datasets_by_ids_none_list(self, db_session_with_containers): + """Test get_datasets_by_ids with None returns empty result.""" + # Arrange + tenant_id = str(uuid4()) + + # Act + datasets, total = DatasetService.get_datasets_by_ids(None, tenant_id) + + # Assert + assert datasets == [] + assert total == 0 + + +class TestDatasetServiceGetProcessRules: + """Comprehensive integration tests for DatasetService.get_process_rules method.""" + + def test_get_process_rules_with_existing_rule(self, db_session_with_containers): + """Test retrieval of process rules when rule exists.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + + rules_data = { + "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}], + "segmentation": {"delimiter": "\n", "max_tokens": 500}, + } + DatasetRetrievalTestDataFactory.create_process_rule( + dataset_id=dataset.id, + created_by=account.id, + mode="custom", + rules=rules_data, + ) + + # Act + result = DatasetService.get_process_rules(dataset.id) + + # Assert + assert result["mode"] == "custom" + assert result["rules"] == rules_data + + def test_get_process_rules_without_existing_rule(self, db_session_with_containers): + """Test retrieval of process rules when no rule exists (returns defaults).""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + + # Act + result = DatasetService.get_process_rules(dataset.id) + + # Assert + assert result["mode"] == DocumentService.DEFAULT_RULES["mode"] + assert "rules" in result + assert result["rules"] == DocumentService.DEFAULT_RULES["rules"] + + +class TestDatasetServiceGetDatasetQueries: + """Comprehensive integration tests for DatasetService.get_dataset_queries method.""" + + def test_get_dataset_queries_success(self, db_session_with_containers): + """Test successful retrieval of dataset queries.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + page = 1 + per_page = 20 + + for i in range(3): + DatasetRetrievalTestDataFactory.create_dataset_query( + dataset_id=dataset.id, + created_by=account.id, + content=f"query-{i}", + ) + + # Act + queries, total = DatasetService.get_dataset_queries(dataset.id, page, per_page) + + # Assert + assert len(queries) == 3 + assert total == 3 + assert all(query.dataset_id == dataset.id for query in queries) + + def test_get_dataset_queries_empty_result(self, db_session_with_containers): + """Test retrieval when no queries exist.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + page = 1 + per_page = 20 + + # Act + queries, total = DatasetService.get_dataset_queries(dataset.id, page, per_page) + + # Assert + assert queries == [] + assert total == 0 + + +class TestDatasetServiceGetRelatedApps: + """Comprehensive integration tests for DatasetService.get_related_apps method.""" + + def test_get_related_apps_success(self, db_session_with_containers): + """Test successful retrieval of related apps.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + + for _ in range(2): + DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id) + + # Act + result = DatasetService.get_related_apps(dataset.id) + + # Assert + assert len(result) == 2 + assert all(join.dataset_id == dataset.id for join in result) + + def test_get_related_apps_empty_result(self, db_session_with_containers): + """Test retrieval when no related apps exist.""" + # Arrange + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + + # Act + result = DatasetService.get_related_apps(dataset.id) + + # Assert + assert result == [] diff --git a/api/tests/unit_tests/services/test_dataset_service_retrieval.py b/api/tests/unit_tests/services/test_dataset_service_retrieval.py deleted file mode 100644 index caf02c159f..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_retrieval.py +++ /dev/null @@ -1,746 +0,0 @@ -""" -Comprehensive unit tests for DatasetService retrieval/list methods. - -This test suite covers: -- get_datasets - pagination, search, filtering, permissions -- get_dataset - single dataset retrieval -- get_datasets_by_ids - bulk retrieval -- get_process_rules - dataset processing rules -- get_dataset_queries - dataset query history -- get_related_apps - apps using the dataset -""" - -from unittest.mock import Mock, create_autospec, patch -from uuid import uuid4 - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import ( - AppDatasetJoin, - Dataset, - DatasetPermission, - DatasetPermissionEnum, - DatasetProcessRule, - DatasetQuery, -) -from services.dataset_service import DatasetService, DocumentService - - -class DatasetRetrievalTestDataFactory: - """Factory class for creating test data and mock objects for dataset retrieval tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - name: str = "Test Dataset", - tenant_id: str = "tenant-123", - created_by: str = "user-123", - permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.name = name - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.permission = permission - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - tenant_id: str = "tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - **kwargs, - ) -> Mock: - """Create a mock account.""" - account = create_autospec(Account, instance=True) - account.id = account_id - account.current_tenant_id = tenant_id - account.current_role = role - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_dataset_permission_mock( - dataset_id: str = "dataset-123", - account_id: str = "account-123", - **kwargs, - ) -> Mock: - """Create a mock dataset permission.""" - permission = Mock(spec=DatasetPermission) - permission.dataset_id = dataset_id - permission.account_id = account_id - for key, value in kwargs.items(): - setattr(permission, key, value) - return permission - - @staticmethod - def create_process_rule_mock( - dataset_id: str = "dataset-123", - mode: str = "automatic", - rules: dict | None = None, - **kwargs, - ) -> Mock: - """Create a mock dataset process rule.""" - process_rule = Mock(spec=DatasetProcessRule) - process_rule.dataset_id = dataset_id - process_rule.mode = mode - process_rule.rules_dict = rules or {} - for key, value in kwargs.items(): - setattr(process_rule, key, value) - return process_rule - - @staticmethod - def create_dataset_query_mock( - dataset_id: str = "dataset-123", - query_id: str = "query-123", - **kwargs, - ) -> Mock: - """Create a mock dataset query.""" - dataset_query = Mock(spec=DatasetQuery) - dataset_query.id = query_id - dataset_query.dataset_id = dataset_id - for key, value in kwargs.items(): - setattr(dataset_query, key, value) - return dataset_query - - @staticmethod - def create_app_dataset_join_mock( - app_id: str = "app-123", - dataset_id: str = "dataset-123", - **kwargs, - ) -> Mock: - """Create a mock app-dataset join.""" - join = Mock(spec=AppDatasetJoin) - join.app_id = app_id - join.dataset_id = dataset_id - for key, value in kwargs.items(): - setattr(join, key, value) - return join - - -class TestDatasetServiceGetDatasets: - """ - Comprehensive unit tests for DatasetService.get_datasets method. - - This test suite covers: - - Pagination - - Search functionality - - Tag filtering - - Permission-based filtering (ONLY_ME, ALL_TEAM, PARTIAL_TEAM) - - Role-based filtering (OWNER, DATASET_OPERATOR, NORMAL) - - include_all flag - """ - - @pytest.fixture - def mock_dependencies(self): - """Common mock setup for get_datasets tests.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.db.paginate") as mock_paginate, - patch("services.dataset_service.TagService") as mock_tag_service, - ): - yield { - "db_session": mock_db, - "paginate": mock_paginate, - "tag_service": mock_tag_service, - } - - # ==================== Basic Retrieval Tests ==================== - - def test_get_datasets_basic_pagination(self, mock_dependencies): - """Test basic pagination without user or filters.""" - # Arrange - tenant_id = str(uuid4()) - page = 1 - per_page = 20 - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock( - dataset_id=f"dataset-{i}", name=f"Dataset {i}", tenant_id=tenant_id - ) - for i in range(5) - ] - mock_paginate_result.total = 5 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id) - - # Assert - assert len(datasets) == 5 - assert total == 5 - mock_dependencies["paginate"].assert_called_once() - - def test_get_datasets_with_search(self, mock_dependencies): - """Test get_datasets with search keyword.""" - # Arrange - tenant_id = str(uuid4()) - page = 1 - per_page = 20 - search = "test" - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock( - dataset_id="dataset-1", name="Test Dataset", tenant_id=tenant_id - ) - ] - mock_paginate_result.total = 1 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, search=search) - - # Assert - assert len(datasets) == 1 - assert total == 1 - mock_dependencies["paginate"].assert_called_once() - - def test_get_datasets_with_tag_filtering(self, mock_dependencies): - """Test get_datasets with tag_ids filtering.""" - # Arrange - tenant_id = str(uuid4()) - page = 1 - per_page = 20 - tag_ids = ["tag-1", "tag-2"] - - # Mock tag service - target_ids = ["dataset-1", "dataset-2"] - mock_dependencies["tag_service"].get_target_ids_by_tag_ids.return_value = target_ids - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id) - for dataset_id in target_ids - ] - mock_paginate_result.total = 2 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids) - - # Assert - assert len(datasets) == 2 - assert total == 2 - mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_called_once_with( - "knowledge", tenant_id, tag_ids - ) - - def test_get_datasets_with_empty_tag_ids(self, mock_dependencies): - """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" - # Arrange - tenant_id = str(uuid4()) - page = 1 - per_page = 20 - tag_ids = [] - - # Mock pagination result - when tag_ids is empty, tag filtering is skipped - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id) - for i in range(3) - ] - mock_paginate_result.total = 3 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids) - - # Assert - # When tag_ids is empty, tag filtering is skipped, so normal query results are returned - assert len(datasets) == 3 - assert total == 3 - # Tag service should not be called when tag_ids is empty - mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_not_called() - mock_dependencies["paginate"].assert_called_once() - - # ==================== Permission-Based Filtering Tests ==================== - - def test_get_datasets_without_user_shows_only_all_team(self, mock_dependencies): - """Test that without user, only ALL_TEAM datasets are shown.""" - # Arrange - tenant_id = str(uuid4()) - page = 1 - per_page = 20 - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock( - dataset_id="dataset-1", - tenant_id=tenant_id, - permission=DatasetPermissionEnum.ALL_TEAM, - ) - ] - mock_paginate_result.total = 1 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, user=None) - - # Assert - assert len(datasets) == 1 - mock_dependencies["paginate"].assert_called_once() - - def test_get_datasets_owner_with_include_all(self, mock_dependencies): - """Test that OWNER with include_all=True sees all datasets.""" - # Arrange - tenant_id = str(uuid4()) - user = DatasetRetrievalTestDataFactory.create_account_mock( - account_id="owner-123", tenant_id=tenant_id, role=TenantAccountRole.OWNER - ) - - # Mock dataset permissions query (empty - owner doesn't need explicit permissions) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [] - mock_dependencies["db_session"].query.return_value = mock_query - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id) - for i in range(3) - ] - mock_paginate_result.total = 3 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets( - page=1, per_page=20, tenant_id=tenant_id, user=user, include_all=True - ) - - # Assert - assert len(datasets) == 3 - assert total == 3 - - def test_get_datasets_normal_user_only_me_permission(self, mock_dependencies): - """Test that normal user sees ONLY_ME datasets they created.""" - # Arrange - tenant_id = str(uuid4()) - user_id = "user-123" - user = DatasetRetrievalTestDataFactory.create_account_mock( - account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL - ) - - # Mock dataset permissions query (no explicit permissions) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [] - mock_dependencies["db_session"].query.return_value = mock_query - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock( - dataset_id="dataset-1", - tenant_id=tenant_id, - created_by=user_id, - permission=DatasetPermissionEnum.ONLY_ME, - ) - ] - mock_paginate_result.total = 1 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user) - - # Assert - assert len(datasets) == 1 - assert total == 1 - - def test_get_datasets_normal_user_all_team_permission(self, mock_dependencies): - """Test that normal user sees ALL_TEAM datasets.""" - # Arrange - tenant_id = str(uuid4()) - user = DatasetRetrievalTestDataFactory.create_account_mock( - account_id="user-123", tenant_id=tenant_id, role=TenantAccountRole.NORMAL - ) - - # Mock dataset permissions query (no explicit permissions) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [] - mock_dependencies["db_session"].query.return_value = mock_query - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock( - dataset_id="dataset-1", - tenant_id=tenant_id, - permission=DatasetPermissionEnum.ALL_TEAM, - ) - ] - mock_paginate_result.total = 1 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user) - - # Assert - assert len(datasets) == 1 - assert total == 1 - - def test_get_datasets_normal_user_partial_team_with_permission(self, mock_dependencies): - """Test that normal user sees PARTIAL_TEAM datasets they have permission for.""" - # Arrange - tenant_id = str(uuid4()) - user_id = "user-123" - dataset_id = "dataset-1" - user = DatasetRetrievalTestDataFactory.create_account_mock( - account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL - ) - - # Mock dataset permissions query - user has permission - permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset_id, account_id=user_id - ) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [permission] - mock_dependencies["db_session"].query.return_value = mock_query - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock( - dataset_id=dataset_id, - tenant_id=tenant_id, - permission=DatasetPermissionEnum.PARTIAL_TEAM, - ) - ] - mock_paginate_result.total = 1 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user) - - # Assert - assert len(datasets) == 1 - assert total == 1 - - def test_get_datasets_dataset_operator_with_permissions(self, mock_dependencies): - """Test that DATASET_OPERATOR only sees datasets they have explicit permission for.""" - # Arrange - tenant_id = str(uuid4()) - user_id = "operator-123" - dataset_id = "dataset-1" - user = DatasetRetrievalTestDataFactory.create_account_mock( - account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR - ) - - # Mock dataset permissions query - operator has permission - permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset_id, account_id=user_id - ) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [permission] - mock_dependencies["db_session"].query.return_value = mock_query - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id) - ] - mock_paginate_result.total = 1 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user) - - # Assert - assert len(datasets) == 1 - assert total == 1 - - def test_get_datasets_dataset_operator_without_permissions(self, mock_dependencies): - """Test that DATASET_OPERATOR without permissions returns empty result.""" - # Arrange - tenant_id = str(uuid4()) - user_id = "operator-123" - user = DatasetRetrievalTestDataFactory.create_account_mock( - account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR - ) - - # Mock dataset permissions query - no permissions - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [] - mock_dependencies["db_session"].query.return_value = mock_query - - # Act - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user) - - # Assert - assert datasets == [] - assert total == 0 - - -class TestDatasetServiceGetDataset: - """Comprehensive unit tests for DatasetService.get_dataset method.""" - - @pytest.fixture - def mock_dependencies(self): - """Common mock setup for get_dataset tests.""" - with patch("services.dataset_service.db.session") as mock_db: - yield {"db_session": mock_db} - - def test_get_dataset_success(self, mock_dependencies): - """Test successful retrieval of a single dataset.""" - # Arrange - dataset_id = str(uuid4()) - dataset = DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id) - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = dataset - mock_dependencies["db_session"].query.return_value = mock_query - - # Act - result = DatasetService.get_dataset(dataset_id) - - # Assert - assert result is not None - assert result.id == dataset_id - mock_query.filter_by.assert_called_once_with(id=dataset_id) - - def test_get_dataset_not_found(self, mock_dependencies): - """Test retrieval when dataset doesn't exist.""" - # Arrange - dataset_id = str(uuid4()) - - # Mock database query returning None - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dependencies["db_session"].query.return_value = mock_query - - # Act - result = DatasetService.get_dataset(dataset_id) - - # Assert - assert result is None - - -class TestDatasetServiceGetDatasetsByIds: - """Comprehensive unit tests for DatasetService.get_datasets_by_ids method.""" - - @pytest.fixture - def mock_dependencies(self): - """Common mock setup for get_datasets_by_ids tests.""" - with patch("services.dataset_service.db.paginate") as mock_paginate: - yield {"paginate": mock_paginate} - - def test_get_datasets_by_ids_success(self, mock_dependencies): - """Test successful bulk retrieval of datasets by IDs.""" - # Arrange - tenant_id = str(uuid4()) - dataset_ids = [str(uuid4()), str(uuid4()), str(uuid4())] - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id) - for dataset_id in dataset_ids - ] - mock_paginate_result.total = len(dataset_ids) - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id) - - # Assert - assert len(datasets) == 3 - assert total == 3 - assert all(dataset.id in dataset_ids for dataset in datasets) - mock_dependencies["paginate"].assert_called_once() - - def test_get_datasets_by_ids_empty_list(self, mock_dependencies): - """Test get_datasets_by_ids with empty list returns empty result.""" - # Arrange - tenant_id = str(uuid4()) - dataset_ids = [] - - # Act - datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id) - - # Assert - assert datasets == [] - assert total == 0 - mock_dependencies["paginate"].assert_not_called() - - def test_get_datasets_by_ids_none_list(self, mock_dependencies): - """Test get_datasets_by_ids with None returns empty result.""" - # Arrange - tenant_id = str(uuid4()) - - # Act - datasets, total = DatasetService.get_datasets_by_ids(None, tenant_id) - - # Assert - assert datasets == [] - assert total == 0 - mock_dependencies["paginate"].assert_not_called() - - -class TestDatasetServiceGetProcessRules: - """Comprehensive unit tests for DatasetService.get_process_rules method.""" - - @pytest.fixture - def mock_dependencies(self): - """Common mock setup for get_process_rules tests.""" - with patch("services.dataset_service.db.session") as mock_db: - yield {"db_session": mock_db} - - def test_get_process_rules_with_existing_rule(self, mock_dependencies): - """Test retrieval of process rules when rule exists.""" - # Arrange - dataset_id = str(uuid4()) - rules_data = { - "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}], - "segmentation": {"delimiter": "\n", "max_tokens": 500}, - } - process_rule = DatasetRetrievalTestDataFactory.create_process_rule_mock( - dataset_id=dataset_id, mode="custom", rules=rules_data - ) - - # Mock database query - mock_query = Mock() - mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = process_rule - mock_dependencies["db_session"].query.return_value = mock_query - - # Act - result = DatasetService.get_process_rules(dataset_id) - - # Assert - assert result["mode"] == "custom" - assert result["rules"] == rules_data - - def test_get_process_rules_without_existing_rule(self, mock_dependencies): - """Test retrieval of process rules when no rule exists (returns defaults).""" - # Arrange - dataset_id = str(uuid4()) - - # Mock database query returning None - mock_query = Mock() - mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = None - mock_dependencies["db_session"].query.return_value = mock_query - - # Act - result = DatasetService.get_process_rules(dataset_id) - - # Assert - assert result["mode"] == DocumentService.DEFAULT_RULES["mode"] - assert "rules" in result - assert result["rules"] == DocumentService.DEFAULT_RULES["rules"] - - -class TestDatasetServiceGetDatasetQueries: - """Comprehensive unit tests for DatasetService.get_dataset_queries method.""" - - @pytest.fixture - def mock_dependencies(self): - """Common mock setup for get_dataset_queries tests.""" - with patch("services.dataset_service.db.paginate") as mock_paginate: - yield {"paginate": mock_paginate} - - def test_get_dataset_queries_success(self, mock_dependencies): - """Test successful retrieval of dataset queries.""" - # Arrange - dataset_id = str(uuid4()) - page = 1 - per_page = 20 - - # Mock pagination result - mock_paginate_result = Mock() - mock_paginate_result.items = [ - DatasetRetrievalTestDataFactory.create_dataset_query_mock(dataset_id=dataset_id, query_id=f"query-{i}") - for i in range(3) - ] - mock_paginate_result.total = 3 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page) - - # Assert - assert len(queries) == 3 - assert total == 3 - assert all(query.dataset_id == dataset_id for query in queries) - mock_dependencies["paginate"].assert_called_once() - - def test_get_dataset_queries_empty_result(self, mock_dependencies): - """Test retrieval when no queries exist.""" - # Arrange - dataset_id = str(uuid4()) - page = 1 - per_page = 20 - - # Mock pagination result (empty) - mock_paginate_result = Mock() - mock_paginate_result.items = [] - mock_paginate_result.total = 0 - mock_dependencies["paginate"].return_value = mock_paginate_result - - # Act - queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page) - - # Assert - assert queries == [] - assert total == 0 - - -class TestDatasetServiceGetRelatedApps: - """Comprehensive unit tests for DatasetService.get_related_apps method.""" - - @pytest.fixture - def mock_dependencies(self): - """Common mock setup for get_related_apps tests.""" - with patch("services.dataset_service.db.session") as mock_db: - yield {"db_session": mock_db} - - def test_get_related_apps_success(self, mock_dependencies): - """Test successful retrieval of related apps.""" - # Arrange - dataset_id = str(uuid4()) - - # Mock app-dataset joins - app_joins = [ - DatasetRetrievalTestDataFactory.create_app_dataset_join_mock(app_id=f"app-{i}", dataset_id=dataset_id) - for i in range(2) - ] - - # Mock database query - mock_query = Mock() - mock_query.where.return_value.order_by.return_value.all.return_value = app_joins - mock_dependencies["db_session"].query.return_value = mock_query - - # Act - result = DatasetService.get_related_apps(dataset_id) - - # Assert - assert len(result) == 2 - assert all(join.dataset_id == dataset_id for join in result) - mock_query.where.assert_called_once() - mock_query.where.return_value.order_by.assert_called_once() - - def test_get_related_apps_empty_result(self, mock_dependencies): - """Test retrieval when no related apps exist.""" - # Arrange - dataset_id = str(uuid4()) - - # Mock database query returning empty list - mock_query = Mock() - mock_query.where.return_value.order_by.return_value.all.return_value = [] - mock_dependencies["db_session"].query.return_value = mock_query - - # Act - result = DatasetService.get_related_apps(dataset_id) - - # Assert - assert result == []