diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 2cea24d085..8c8be2e670 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -295,9 +295,13 @@ class TestAPIBasedExtensionService: original_name = created_extension.name original_endpoint = created_extension.api_endpoint - # Update the extension + # Update the extension with guaranteed different values new_name = fake.company() + # Ensure new endpoint is different from original new_endpoint = f"https://{fake.domain_name()}/api" + # If by chance they're the same, generate a new one + while new_endpoint == original_endpoint: + new_endpoint = f"https://{fake.domain_name()}/api" new_api_key = fake.password(length=20) created_extension.name = new_name diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py new file mode 100644 index 0000000000..ee05e890b2 --- /dev/null +++ b/api/tests/unit_tests/services/segment_service.py @@ -0,0 +1,1093 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from models.account import Account +from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from services.dataset_service import SegmentService +from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError + + +class SegmentTestDataFactory: + """Factory class for creating test data and mock objects for segment service tests.""" + + @staticmethod + def create_segment_mock( + segment_id: str = "segment-123", + document_id: str = "doc-123", + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + content: str = "Test segment content", + position: int = 1, + enabled: bool = True, + status: str = "completed", + word_count: int = 3, + tokens: int = 5, + **kwargs, + ) -> Mock: + """Create a mock segment with specified attributes.""" + segment = Mock(spec=DocumentSegment) + segment.id = segment_id + segment.document_id = document_id + segment.dataset_id = dataset_id + segment.tenant_id = tenant_id + segment.content = content + segment.position = position + segment.enabled = enabled + segment.status = status + segment.word_count = word_count + segment.tokens = tokens + segment.index_node_id = f"node-{segment_id}" + segment.index_node_hash = "hash-123" + segment.keywords = [] + segment.answer = None + segment.disabled_at = None + segment.disabled_by = None + segment.updated_by = None + segment.updated_at = None + segment.indexing_at = None + segment.completed_at = None + segment.error = None + for key, value in kwargs.items(): + setattr(segment, key, value) + return segment + + @staticmethod + def create_child_chunk_mock( + chunk_id: str = "chunk-123", + segment_id: str = "segment-123", + document_id: str = "doc-123", + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + content: str = "Test child chunk content", + position: int = 1, + word_count: int = 3, + **kwargs, + ) -> Mock: + """Create a mock child chunk with specified attributes.""" + chunk = Mock(spec=ChildChunk) + chunk.id = chunk_id + chunk.segment_id = segment_id + chunk.document_id = document_id + chunk.dataset_id = dataset_id + chunk.tenant_id = tenant_id + chunk.content = content + chunk.position = position + chunk.word_count = word_count + chunk.index_node_id = f"node-{chunk_id}" + chunk.index_node_hash = "hash-123" + chunk.type = "automatic" + chunk.created_by = "user-123" + chunk.updated_by = None + chunk.updated_at = None + for key, value in kwargs.items(): + setattr(chunk, key, value) + return chunk + + @staticmethod + def create_document_mock( + document_id: str = "doc-123", + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + doc_form: str = "text_model", + word_count: int = 100, + **kwargs, + ) -> Mock: + """Create a mock document with specified attributes.""" + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = tenant_id + document.doc_form = doc_form + document.word_count = word_count + for key, value in kwargs.items(): + setattr(document, key, value) + return document + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + indexing_technique: str = "high_quality", + embedding_model: str = "text-embedding-ada-002", + embedding_model_provider: str = "openai", + **kwargs, + ) -> Mock: + """Create a mock dataset with specified attributes.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = indexing_technique + dataset.embedding_model = embedding_model + dataset.embedding_model_provider = embedding_model_provider + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock( + user_id: str = "user-789", + tenant_id: str = "tenant-123", + **kwargs, + ) -> Mock: + """Create a mock user with specified attributes.""" + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + user.name = "Test User" + for key, value in kwargs.items(): + setattr(user, key, value) + return user + + +class TestSegmentServiceCreateSegment: + """Tests for SegmentService.create_segment method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_create_segment_success(self, mock_db_session, mock_current_user): + """Test successful creation of a segment.""" + # Arrange + document = SegmentTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + args = {"content": "New segment content", "keywords": ["test", "segment"]} + + mock_query = MagicMock() + mock_query.where.return_value.scalar.return_value = None # No existing segments + mock_db_session.query.return_value = mock_query + + mock_segment = SegmentTestDataFactory.create_segment_mock() + mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment + + with ( + patch("services.dataset_service.redis_client.lock") as mock_lock, + patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_lock.return_value.__enter__ = Mock() + mock_lock.return_value.__exit__ = Mock(return_value=None) + mock_hash.return_value = "hash-123" + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.create_segment(args, document, dataset) + + # Assert + assert mock_db_session.add.call_count == 2 + + created_segment = mock_db_session.add.call_args_list[0].args[0] + assert isinstance(created_segment, DocumentSegment) + assert created_segment.content == args["content"] + assert created_segment.word_count == len(args["content"]) + + mock_db_session.commit.assert_called_once() + + mock_vector_service.assert_called_once() + vector_call_args = mock_vector_service.call_args[0] + assert vector_call_args[0] == [args["keywords"]] + assert vector_call_args[1][0] == created_segment + assert vector_call_args[2] == dataset + assert vector_call_args[3] == document.doc_form + + assert result == mock_segment + + def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): + """Test creation of segment with QA model (requires answer).""" + # Arrange + document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} + + mock_query = MagicMock() + mock_query.where.return_value.scalar.return_value = None + mock_db_session.query.return_value = mock_query + + mock_segment = SegmentTestDataFactory.create_segment_mock() + mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment + + with ( + patch("services.dataset_service.redis_client.lock") as mock_lock, + patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_lock.return_value.__enter__ = Mock() + mock_lock.return_value.__exit__ = Mock(return_value=None) + mock_hash.return_value = "hash-123" + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.create_segment(args, document, dataset) + + # Assert + assert result == mock_segment + mock_db_session.add.assert_called() + mock_db_session.commit.assert_called() + + def test_create_segment_with_high_quality_indexing(self, mock_db_session, mock_current_user): + """Test creation of segment with high quality indexing technique.""" + # Arrange + document = SegmentTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + args = {"content": "New segment content", "keywords": ["test"]} + + mock_query = MagicMock() + mock_query.where.return_value.scalar.return_value = None + mock_db_session.query.return_value = mock_query + + mock_embedding_model = MagicMock() + mock_embedding_model.get_text_embedding_num_tokens.return_value = [10] + mock_model_manager = MagicMock() + mock_model_manager.get_model_instance.return_value = mock_embedding_model + + mock_segment = SegmentTestDataFactory.create_segment_mock() + mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment + + with ( + patch("services.dataset_service.redis_client.lock") as mock_lock, + patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, + patch("services.dataset_service.ModelManager") as mock_model_manager_class, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_lock.return_value.__enter__ = Mock() + mock_lock.return_value.__exit__ = Mock(return_value=None) + mock_model_manager_class.return_value = mock_model_manager + mock_hash.return_value = "hash-123" + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.create_segment(args, document, dataset) + + # Assert + assert result == mock_segment + mock_model_manager.get_model_instance.assert_called_once() + mock_embedding_model.get_text_embedding_num_tokens.assert_called_once() + + def test_create_segment_vector_index_failure(self, mock_db_session, mock_current_user): + """Test segment creation when vector indexing fails.""" + # Arrange + document = SegmentTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + args = {"content": "New segment content", "keywords": ["test"]} + + mock_query = MagicMock() + mock_query.where.return_value.scalar.return_value = None + mock_db_session.query.return_value = mock_query + + mock_segment = SegmentTestDataFactory.create_segment_mock(enabled=False, status="error") + mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment + + with ( + patch("services.dataset_service.redis_client.lock") as mock_lock, + patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_lock.return_value.__enter__ = Mock() + mock_lock.return_value.__exit__ = Mock(return_value=None) + mock_vector_service.side_effect = Exception("Vector indexing failed") + mock_hash.return_value = "hash-123" + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.create_segment(args, document, dataset) + + # Assert + assert result == mock_segment + assert mock_db_session.commit.call_count == 2 # Once for creation, once for error update + + +class TestSegmentServiceUpdateSegment: + """Tests for SegmentService.update_segment method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_update_segment_content_success(self, mock_db_session, mock_current_user): + """Test successful update of segment content.""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) + document = SegmentTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) + + mock_db_session.query.return_value.where.return_value.first.return_value = segment + + with ( + patch("services.dataset_service.redis_client.get") as mock_redis_get, + patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_redis_get.return_value = None # Not indexing + mock_hash.return_value = "new-hash" + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.update_segment(args, segment, document, dataset) + + # Assert + assert result == segment + assert segment.content == "Updated content" + assert segment.keywords == ["updated"] + assert segment.word_count == len("Updated content") + assert document.word_count == 100 + (len("Updated content") - 10) + mock_db_session.add.assert_called() + mock_db_session.commit.assert_called() + + def test_update_segment_disable(self, mock_db_session, mock_current_user): + """Test disabling a segment.""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=True) + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + args = SegmentUpdateArgs(enabled=False) + + with ( + patch("services.dataset_service.redis_client.get") as mock_redis_get, + patch("services.dataset_service.redis_client.setex") as mock_redis_setex, + patch("services.dataset_service.disable_segment_from_index_task") as mock_task, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_redis_get.return_value = None + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.update_segment(args, segment, document, dataset) + + # Assert + assert result == segment + assert segment.enabled is False + mock_db_session.add.assert_called() + mock_db_session.commit.assert_called() + mock_task.delay.assert_called_once() + + def test_update_segment_indexing_in_progress(self, mock_db_session, mock_current_user): + """Test update fails when segment is currently indexing.""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=True) + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + args = SegmentUpdateArgs(content="Updated content") + + with patch("services.dataset_service.redis_client.get") as mock_redis_get: + mock_redis_get.return_value = "1" # Indexing in progress + + # Act & Assert + with pytest.raises(ValueError, match="Segment is indexing"): + SegmentService.update_segment(args, segment, document, dataset) + + def test_update_segment_disabled_segment(self, mock_db_session, mock_current_user): + """Test update fails when segment is disabled.""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=False) + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + args = SegmentUpdateArgs(content="Updated content") + + with patch("services.dataset_service.redis_client.get") as mock_redis_get: + mock_redis_get.return_value = None + + # Act & Assert + with pytest.raises(ValueError, match="Can't update disabled segment"): + SegmentService.update_segment(args, segment, document, dataset) + + def test_update_segment_with_qa_model(self, mock_db_session, mock_current_user): + """Test update segment with QA model (includes answer).""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) + document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) + + mock_db_session.query.return_value.where.return_value.first.return_value = segment + + with ( + patch("services.dataset_service.redis_client.get") as mock_redis_get, + patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_redis_get.return_value = None + mock_hash.return_value = "new-hash" + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.update_segment(args, segment, document, dataset) + + # Assert + assert result == segment + assert segment.content == "Updated question" + assert segment.answer == "Updated answer" + assert segment.keywords == ["qa"] + new_word_count = len("Updated question") + len("Updated answer") + assert segment.word_count == new_word_count + assert document.word_count == 100 + (new_word_count - 10) + mock_db_session.commit.assert_called() + + +class TestSegmentServiceDeleteSegment: + """Tests for SegmentService.delete_segment method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + def test_delete_segment_success(self, mock_db_session): + """Test successful deletion of a segment.""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=50) + document = SegmentTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock() + + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_db_session.scalars.return_value = mock_scalars + + with ( + patch("services.dataset_service.redis_client.get") as mock_redis_get, + patch("services.dataset_service.redis_client.setex") as mock_redis_setex, + patch("services.dataset_service.delete_segment_from_index_task") as mock_task, + patch("services.dataset_service.select") as mock_select, + ): + mock_redis_get.return_value = None + mock_select.return_value.where.return_value = mock_select + + # Act + SegmentService.delete_segment(segment, document, dataset) + + # Assert + mock_db_session.delete.assert_called_once_with(segment) + mock_db_session.commit.assert_called_once() + mock_task.delay.assert_called_once() + + def test_delete_segment_disabled(self, mock_db_session): + """Test deletion of disabled segment (no index deletion).""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=False, word_count=50) + document = SegmentTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentTestDataFactory.create_dataset_mock() + + with ( + patch("services.dataset_service.redis_client.get") as mock_redis_get, + patch("services.dataset_service.delete_segment_from_index_task") as mock_task, + ): + mock_redis_get.return_value = None + + # Act + SegmentService.delete_segment(segment, document, dataset) + + # Assert + mock_db_session.delete.assert_called_once_with(segment) + mock_db_session.commit.assert_called_once() + mock_task.delay.assert_not_called() + + def test_delete_segment_indexing_in_progress(self, mock_db_session): + """Test deletion fails when segment is currently being deleted.""" + # Arrange + segment = SegmentTestDataFactory.create_segment_mock(enabled=True) + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + with patch("services.dataset_service.redis_client.get") as mock_redis_get: + mock_redis_get.return_value = "1" # Deletion in progress + + # Act & Assert + with pytest.raises(ValueError, match="Segment is deleting"): + SegmentService.delete_segment(segment, document, dataset) + + +class TestSegmentServiceDeleteSegments: + """Tests for SegmentService.delete_segments method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_delete_segments_success(self, mock_db_session, mock_current_user): + """Test successful deletion of multiple segments.""" + # Arrange + segment_ids = ["segment-1", "segment-2"] + document = SegmentTestDataFactory.create_document_mock(word_count=200) + dataset = SegmentTestDataFactory.create_dataset_mock() + + segments_info = [ + ("node-1", "segment-1", 50), + ("node-2", "segment-2", 30), + ] + + mock_query = MagicMock() + mock_query.with_entities.return_value.where.return_value.all.return_value = segments_info + mock_db_session.query.return_value = mock_query + + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_select = MagicMock() + mock_select.where.return_value = mock_select + mock_db_session.scalars.return_value = mock_scalars + + with ( + patch("services.dataset_service.delete_segment_from_index_task") as mock_task, + patch("services.dataset_service.select") as mock_select_func, + ): + mock_select_func.return_value = mock_select + + # Act + SegmentService.delete_segments(segment_ids, document, dataset) + + # Assert + mock_db_session.query.return_value.where.return_value.delete.assert_called_once() + mock_db_session.commit.assert_called_once() + mock_task.delay.assert_called_once() + + def test_delete_segments_empty_list(self, mock_db_session, mock_current_user): + """Test deletion with empty list (should return early).""" + # Arrange + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + # Act + SegmentService.delete_segments([], document, dataset) + + # Assert + mock_db_session.query.assert_not_called() + + +class TestSegmentServiceUpdateSegmentsStatus: + """Tests for SegmentService.update_segments_status method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_update_segments_status_enable(self, mock_db_session, mock_current_user): + """Test enabling multiple segments.""" + # Arrange + segment_ids = ["segment-1", "segment-2"] + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + segments = [ + SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=False), + SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=False), + ] + + mock_scalars = MagicMock() + mock_scalars.all.return_value = segments + mock_select = MagicMock() + mock_select.where.return_value = mock_select + mock_db_session.scalars.return_value = mock_scalars + + with ( + patch("services.dataset_service.redis_client.get") as mock_redis_get, + patch("services.dataset_service.enable_segments_to_index_task") as mock_task, + patch("services.dataset_service.select") as mock_select_func, + ): + mock_redis_get.return_value = None + mock_select_func.return_value = mock_select + + # Act + SegmentService.update_segments_status(segment_ids, "enable", dataset, document) + + # Assert + assert all(seg.enabled is True for seg in segments) + mock_db_session.commit.assert_called_once() + mock_task.delay.assert_called_once() + + def test_update_segments_status_disable(self, mock_db_session, mock_current_user): + """Test disabling multiple segments.""" + # Arrange + segment_ids = ["segment-1", "segment-2"] + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + segments = [ + SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=True), + SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=True), + ] + + mock_scalars = MagicMock() + mock_scalars.all.return_value = segments + mock_select = MagicMock() + mock_select.where.return_value = mock_select + mock_db_session.scalars.return_value = mock_scalars + + with ( + patch("services.dataset_service.redis_client.get") as mock_redis_get, + patch("services.dataset_service.disable_segments_from_index_task") as mock_task, + patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.select") as mock_select_func, + ): + mock_redis_get.return_value = None + mock_now.return_value = "2024-01-01T00:00:00" + mock_select_func.return_value = mock_select + + # Act + SegmentService.update_segments_status(segment_ids, "disable", dataset, document) + + # Assert + assert all(seg.enabled is False for seg in segments) + mock_db_session.commit.assert_called_once() + mock_task.delay.assert_called_once() + + def test_update_segments_status_empty_list(self, mock_db_session, mock_current_user): + """Test update with empty list (should return early).""" + # Arrange + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + # Act + SegmentService.update_segments_status([], "enable", dataset, document) + + # Assert + mock_db_session.scalars.assert_not_called() + + +class TestSegmentServiceGetSegments: + """Tests for SegmentService.get_segments method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_get_segments_success(self, mock_db_session, mock_current_user): + """Test successful retrieval of segments.""" + # Arrange + document_id = "doc-123" + tenant_id = "tenant-123" + segments = [ + SegmentTestDataFactory.create_segment_mock(segment_id="segment-1"), + SegmentTestDataFactory.create_segment_mock(segment_id="segment-2"), + ] + + mock_paginate = MagicMock() + mock_paginate.items = segments + mock_paginate.total = 2 + mock_db_session.paginate.return_value = mock_paginate + + # Act + items, total = SegmentService.get_segments(document_id, tenant_id) + + # Assert + assert len(items) == 2 + assert total == 2 + mock_db_session.paginate.assert_called_once() + + def test_get_segments_with_status_filter(self, mock_db_session, mock_current_user): + """Test retrieval with status filter.""" + # Arrange + document_id = "doc-123" + tenant_id = "tenant-123" + status_list = ["completed", "error"] + + mock_paginate = MagicMock() + mock_paginate.items = [] + mock_paginate.total = 0 + mock_db_session.paginate.return_value = mock_paginate + + # Act + items, total = SegmentService.get_segments(document_id, tenant_id, status_list=status_list) + + # Assert + assert len(items) == 0 + assert total == 0 + + def test_get_segments_with_keyword(self, mock_db_session, mock_current_user): + """Test retrieval with keyword search.""" + # Arrange + document_id = "doc-123" + tenant_id = "tenant-123" + keyword = "test" + + mock_paginate = MagicMock() + mock_paginate.items = [SegmentTestDataFactory.create_segment_mock()] + mock_paginate.total = 1 + mock_db_session.paginate.return_value = mock_paginate + + # Act + items, total = SegmentService.get_segments(document_id, tenant_id, keyword=keyword) + + # Assert + assert len(items) == 1 + assert total == 1 + + +class TestSegmentServiceGetSegmentById: + """Tests for SegmentService.get_segment_by_id method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + def test_get_segment_by_id_success(self, mock_db_session): + """Test successful retrieval of segment by ID.""" + # Arrange + segment_id = "segment-123" + tenant_id = "tenant-123" + segment = SegmentTestDataFactory.create_segment_mock(segment_id=segment_id) + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = segment + mock_db_session.query.return_value = mock_query + + # Act + result = SegmentService.get_segment_by_id(segment_id, tenant_id) + + # Assert + assert result == segment + + def test_get_segment_by_id_not_found(self, mock_db_session): + """Test retrieval when segment is not found.""" + # Arrange + segment_id = "non-existent" + tenant_id = "tenant-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = None + mock_db_session.query.return_value = mock_query + + # Act + result = SegmentService.get_segment_by_id(segment_id, tenant_id) + + # Assert + assert result is None + + +class TestSegmentServiceGetChildChunks: + """Tests for SegmentService.get_child_chunks method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_get_child_chunks_success(self, mock_db_session, mock_current_user): + """Test successful retrieval of child chunks.""" + # Arrange + segment_id = "segment-123" + document_id = "doc-123" + dataset_id = "dataset-123" + page = 1 + limit = 20 + + mock_paginate = MagicMock() + mock_paginate.items = [ + SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-1"), + SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-2"), + ] + mock_paginate.total = 2 + mock_db_session.paginate.return_value = mock_paginate + + # Act + result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit) + + # Assert + assert result == mock_paginate + mock_db_session.paginate.assert_called_once() + + def test_get_child_chunks_with_keyword(self, mock_db_session, mock_current_user): + """Test retrieval with keyword search.""" + # Arrange + segment_id = "segment-123" + document_id = "doc-123" + dataset_id = "dataset-123" + page = 1 + limit = 20 + keyword = "test" + + mock_paginate = MagicMock() + mock_paginate.items = [] + mock_paginate.total = 0 + mock_db_session.paginate.return_value = mock_paginate + + # Act + result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword=keyword) + + # Assert + assert result == mock_paginate + + +class TestSegmentServiceGetChildChunkById: + """Tests for SegmentService.get_child_chunk_by_id method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + def test_get_child_chunk_by_id_success(self, mock_db_session): + """Test successful retrieval of child chunk by ID.""" + # Arrange + chunk_id = "chunk-123" + tenant_id = "tenant-123" + chunk = SegmentTestDataFactory.create_child_chunk_mock(chunk_id=chunk_id) + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = chunk + mock_db_session.query.return_value = mock_query + + # Act + result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id) + + # Assert + assert result == chunk + + def test_get_child_chunk_by_id_not_found(self, mock_db_session): + """Test retrieval when child chunk is not found.""" + # Arrange + chunk_id = "non-existent" + tenant_id = "tenant-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = None + mock_db_session.query.return_value = mock_query + + # Act + result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id) + + # Assert + assert result is None + + +class TestSegmentServiceCreateChildChunk: + """Tests for SegmentService.create_child_chunk method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_create_child_chunk_success(self, mock_db_session, mock_current_user): + """Test successful creation of a child chunk.""" + # Arrange + content = "New child chunk content" + segment = SegmentTestDataFactory.create_segment_mock() + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + mock_query = MagicMock() + mock_query.where.return_value.scalar.return_value = None + mock_db_session.query.return_value = mock_query + + with ( + patch("services.dataset_service.redis_client.lock") as mock_lock, + patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + ): + mock_lock.return_value.__enter__ = Mock() + mock_lock.return_value.__exit__ = Mock(return_value=None) + mock_hash.return_value = "hash-123" + + # Act + result = SegmentService.create_child_chunk(content, segment, document, dataset) + + # Assert + assert result is not None + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + mock_vector_service.assert_called_once() + + def test_create_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user): + """Test child chunk creation when vector indexing fails.""" + # Arrange + content = "New child chunk content" + segment = SegmentTestDataFactory.create_segment_mock() + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + mock_query = MagicMock() + mock_query.where.return_value.scalar.return_value = None + mock_db_session.query.return_value = mock_query + + with ( + patch("services.dataset_service.redis_client.lock") as mock_lock, + patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + ): + mock_lock.return_value.__enter__ = Mock() + mock_lock.return_value.__exit__ = Mock(return_value=None) + mock_vector_service.side_effect = Exception("Vector indexing failed") + mock_hash.return_value = "hash-123" + + # Act & Assert + with pytest.raises(ChildChunkIndexingError): + SegmentService.create_child_chunk(content, segment, document, dataset) + + mock_db_session.rollback.assert_called_once() + + +class TestSegmentServiceUpdateChildChunk: + """Tests for SegmentService.update_child_chunk method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + @pytest.fixture + def mock_current_user(self): + """Mock current_user.""" + user = SegmentTestDataFactory.create_user_mock() + with patch("services.dataset_service.current_user", user): + yield user + + def test_update_child_chunk_success(self, mock_db_session, mock_current_user): + """Test successful update of a child chunk.""" + # Arrange + content = "Updated child chunk content" + chunk = SegmentTestDataFactory.create_child_chunk_mock() + segment = SegmentTestDataFactory.create_segment_mock() + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + with ( + patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_now.return_value = "2024-01-01T00:00:00" + + # Act + result = SegmentService.update_child_chunk(content, chunk, segment, document, dataset) + + # Assert + assert result == chunk + assert chunk.content == content + assert chunk.word_count == len(content) + mock_db_session.add.assert_called_once_with(chunk) + mock_db_session.commit.assert_called_once() + mock_vector_service.assert_called_once() + + def test_update_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user): + """Test child chunk update when vector indexing fails.""" + # Arrange + content = "Updated content" + chunk = SegmentTestDataFactory.create_child_chunk_mock() + segment = SegmentTestDataFactory.create_segment_mock() + document = SegmentTestDataFactory.create_document_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + with ( + patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service, + patch("services.dataset_service.naive_utc_now") as mock_now, + ): + mock_vector_service.side_effect = Exception("Vector indexing failed") + mock_now.return_value = "2024-01-01T00:00:00" + + # Act & Assert + with pytest.raises(ChildChunkIndexingError): + SegmentService.update_child_chunk(content, chunk, segment, document, dataset) + + mock_db_session.rollback.assert_called_once() + + +class TestSegmentServiceDeleteChildChunk: + """Tests for SegmentService.delete_child_chunk method.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.dataset_service.db.session") as mock_db: + yield mock_db + + def test_delete_child_chunk_success(self, mock_db_session): + """Test successful deletion of a child chunk.""" + # Arrange + chunk = SegmentTestDataFactory.create_child_chunk_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service: + # Act + SegmentService.delete_child_chunk(chunk, dataset) + + # Assert + mock_db_session.delete.assert_called_once_with(chunk) + mock_db_session.commit.assert_called_once() + mock_vector_service.assert_called_once_with(chunk, dataset) + + def test_delete_child_chunk_vector_index_failure(self, mock_db_session): + """Test child chunk deletion when vector indexing fails.""" + # Arrange + chunk = SegmentTestDataFactory.create_child_chunk_mock() + dataset = SegmentTestDataFactory.create_dataset_mock() + + with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service: + mock_vector_service.side_effect = Exception("Vector deletion failed") + + # Act & Assert + with pytest.raises(ChildChunkDeleteIndexError): + SegmentService.delete_child_chunk(chunk, dataset) + + mock_db_session.rollback.assert_called_once()