diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py
new file mode 100644
index 0000000000..2322c556e2
--- /dev/null
+++ b/api/tests/unit_tests/models/test_dataset_models.py
@@ -0,0 +1,1341 @@
+"""
+Comprehensive unit tests for Dataset models.
+
+This test suite covers:
+- Dataset model validation
+- Document model relationships
+- Segment model indexing
+- Dataset-Document count and aggregation properties
+- Embedding storage validation
+"""
+
+import json
+import pickle
+from datetime import UTC, datetime
+from unittest.mock import MagicMock, patch
+from uuid import uuid4
+
+from models.dataset import (
+    AppDatasetJoin,
+    ChildChunk,
+    Dataset,
+    DatasetKeywordTable,
+    DatasetProcessRule,
+    Document,
+    DocumentSegment,
+    Embedding,
+)
+
+
+class TestDatasetModelValidation:
+    """Test suite for Dataset model validation and basic operations."""
+
+    def test_dataset_creation_with_required_fields(self):
+        """Test creating a dataset with all required fields."""
+        # Arrange
+        tenant_id = str(uuid4())
+        created_by = str(uuid4())
+
+        # Act
+        dataset = Dataset(
+            tenant_id=tenant_id,
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=created_by,
+        )
+
+        # Assert
+        assert dataset.name == "Test Dataset"
+        assert dataset.tenant_id == tenant_id
+        assert dataset.data_source_type == "upload_file"
+        assert dataset.created_by == created_by
+        # Note: Default values are set by the database, not by model instantiation
+
+    def test_dataset_creation_with_optional_fields(self):
+        """Test creating a dataset with optional fields."""
+        # Arrange & Act
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+            description="Test description",
+            indexing_technique="high_quality",
+            embedding_model="text-embedding-ada-002",
+            embedding_model_provider="openai",
+        )
+
+        # Assert
+        assert dataset.description == "Test description"
+        assert dataset.indexing_technique == "high_quality"
+        assert dataset.embedding_model == "text-embedding-ada-002"
+        assert dataset.embedding_model_provider == "openai"
+
+    def test_dataset_indexing_technique_validation(self):
+        """Test dataset indexing technique values."""
+        # Arrange & Act
+        dataset_high_quality = Dataset(
+            tenant_id=str(uuid4()),
+            name="High Quality Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+            indexing_technique="high_quality",
+        )
+        dataset_economy = Dataset(
+            tenant_id=str(uuid4()),
+            name="Economy Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+            indexing_technique="economy",
+        )
+
+        # Assert
+        assert dataset_high_quality.indexing_technique == "high_quality"
+        assert dataset_economy.indexing_technique == "economy"
+        assert "high_quality" in Dataset.INDEXING_TECHNIQUE_LIST
+        assert "economy" in Dataset.INDEXING_TECHNIQUE_LIST
+
+    def test_dataset_provider_validation(self):
+        """Test dataset provider values."""
+        # Arrange & Act
+        dataset_vendor = Dataset(
+            tenant_id=str(uuid4()),
+            name="Vendor Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+            provider="vendor",
+        )
+        dataset_external = Dataset(
+            tenant_id=str(uuid4()),
+            name="External Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+            provider="external",
+        )
+
+        # Assert
+        assert dataset_vendor.provider == "vendor"
+        assert dataset_external.provider == "external"
+        assert "vendor" in Dataset.PROVIDER_LIST
+        assert "external" in Dataset.PROVIDER_LIST
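+
+    def test_dataset_enumerated_field_round_trip(self):
+        """Supplementary sketch (not an original scenario): every value in the
+        INDEXING_TECHNIQUE_LIST and PROVIDER_LIST constants should be
+        assignable and stored verbatim, since both are plain columns."""
+        for technique in Dataset.INDEXING_TECHNIQUE_LIST:
+            dataset = Dataset(
+                tenant_id=str(uuid4()),
+                name=f"{technique} Dataset",
+                data_source_type="upload_file",
+                created_by=str(uuid4()),
+                indexing_technique=technique,
+            )
+            assert dataset.indexing_technique == technique
+        for provider in Dataset.PROVIDER_LIST:
+            dataset = Dataset(
+                tenant_id=str(uuid4()),
+                name=f"{provider} Dataset",
+                data_source_type="upload_file",
+                created_by=str(uuid4()),
+                provider=provider,
+            )
+            assert dataset.provider == provider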
+
+    def test_dataset_index_struct_dict_property(self):
+        """Test index_struct_dict property parsing."""
+        # Arrange
+        index_struct_data = {"type": "vector", "dimension": 1536}
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+            index_struct=json.dumps(index_struct_data),
+        )
+
+        # Act
+        result = dataset.index_struct_dict
+
+        # Assert
+        assert result == index_struct_data
+        assert result["type"] == "vector"
+        assert result["dimension"] == 1536
+
+    def test_dataset_index_struct_dict_property_none(self):
+        """Test index_struct_dict property when index_struct is None."""
+        # Arrange
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+
+        # Act
+        result = dataset.index_struct_dict
+
+        # Assert
+        assert result is None
+
+    def test_dataset_external_retrieval_model_property(self):
+        """Test external_retrieval_model property with default values."""
+        # Arrange
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+
+        # Act
+        result = dataset.external_retrieval_model
+
+        # Assert
+        assert result["top_k"] == 2
+        assert result["score_threshold"] == 0.0
+
+    def test_dataset_retrieval_model_dict_property(self):
+        """Test retrieval_model_dict property with default values."""
+        # Arrange
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+
+        # Act
+        result = dataset.retrieval_model_dict
+
+        # Assert
+        assert result["top_k"] == 2
+        assert result["reranking_enable"] is False
+        assert result["score_threshold_enabled"] is False
+
+    def test_dataset_gen_collection_name_by_id(self):
+        """Test static method for generating collection name."""
+        # Arrange
+        dataset_id = "12345678-1234-1234-1234-123456789abc"
+
+        # Act
+        collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+
+        # Assert
+        assert "12345678_1234_1234_1234_123456789abc" in collection_name
+        assert "-" not in collection_name.split("_")[-1]
+
+
+class TestDocumentModelRelationships:
+    """Test suite for Document model relationships and properties."""
+
+    def test_document_creation_with_required_fields(self):
+        """Test creating a document with all required fields."""
+        # Arrange
+        tenant_id = str(uuid4())
+        dataset_id = str(uuid4())
+        created_by = str(uuid4())
+
+        # Act
+        document = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset_id,
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test_document.pdf",
+            created_from="web",
+            created_by=created_by,
+        )
+
+        # Assert
+        assert document.tenant_id == tenant_id
+        assert document.dataset_id == dataset_id
+        assert document.position == 1
+        assert document.data_source_type == "upload_file"
+        assert document.batch == "batch_001"
+        assert document.name == "test_document.pdf"
+        assert document.created_from == "web"
+        assert document.created_by == created_by
+        # Note: Default values are set by the database, not by model instantiation
+
+    def test_document_data_source_types(self):
+        """Test document data source type validation."""
+        # Assert
+        assert "upload_file" in Document.DATA_SOURCES
+        assert "notion_import" in Document.DATA_SOURCES
+        assert "website_crawl" in Document.DATA_SOURCES
+
+    def test_document_display_status_queuing(self):
+        """Test document display_status property for queuing state."""
+        # Arrange
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+            indexing_status="waiting",
+        )
+
+        # Act
+        status = document.display_status
+
+        # Assert
+        assert status == "queuing"
+
+    def test_document_display_status_paused(self):
+        """Test document display_status property for paused state."""
+        # Arrange
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+            indexing_status="parsing",
+            is_paused=True,
+        )
+
+        # Act
+        status = document.display_status
+
+        # Assert
+        assert status == "paused"
+
+    def test_document_display_status_indexing(self):
+        """Test document display_status property for indexing state."""
+        # Arrange
+        for indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
+            document = Document(
+                tenant_id=str(uuid4()),
+                dataset_id=str(uuid4()),
+                position=1,
+                data_source_type="upload_file",
+                batch="batch_001",
+                name="test.pdf",
+                created_from="web",
+                created_by=str(uuid4()),
+                indexing_status=indexing_status,
+            )
+
+            # Act
+            status = document.display_status
+
+            # Assert
+            assert status == "indexing"
+
+    def test_document_display_status_error(self):
+        """Test document display_status property for error state."""
+        # Arrange
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+            indexing_status="error",
+        )
+
+        # Act
+        status = document.display_status
+
+        # Assert
+        assert status == "error"
+
+    def test_document_display_status_available(self):
+        """Test document display_status property for available state."""
+        # Arrange
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+            indexing_status="completed",
+            enabled=True,
+            archived=False,
+        )
+
+        # Act
+        status = document.display_status
+
+        # Assert
+        assert status == "available"
+
+    def test_document_display_status_disabled(self):
+        """Test document display_status property for disabled state."""
+        # Arrange
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+            indexing_status="completed",
+            enabled=False,
+            archived=False,
+        )
+
+        # Act
+        status = document.display_status
+
+        # Assert
+        assert status == "disabled"
+
+    def test_document_display_status_archived(self):
+        """Test document display_status property for archived state."""
+        # Arrange
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+            indexing_status="completed",
+            archived=True,
+        )
+
+        # Act
+        status = document.display_status
+
+        # Assert
+        assert status == "archived"
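+
+    def test_document_display_status_matrix(self):
+        """Supplementary table-driven sketch (not an original scenario): it
+        recombines only the states the individual display_status tests above
+        already exercise."""
+        cases = [
+            ({"indexing_status": "waiting"}, "queuing"),
+            ({"indexing_status": "parsing", "is_paused": True}, "paused"),
+            ({"indexing_status": "error"}, "error"),
+            ({"indexing_status": "completed", "enabled": True, "archived": False}, "available"),
+            ({"indexing_status": "completed", "enabled": False, "archived": False}, "disabled"),
+            ({"indexing_status": "completed", "archived": True}, "archived"),
+        ]
+        for extra_fields, expected in cases:
+            document = Document(
+                tenant_id=str(uuid4()),
+                dataset_id=str(uuid4()),
+                position=1,
+                data_source_type="upload_file",
+                batch="batch_001",
+                name="test.pdf",
+                created_from="web",
+                created_by=str(uuid4()),
+                **extra_fields,
+            )
+            assert document.display_status == expected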
data_source_type="upload_file", + batch="batch_001", + name="test.pdf", + created_from="web", + created_by=str(uuid4()), + data_source_info=json.dumps(data_source_info), + ) + + # Act + result = document.data_source_info_dict + + # Assert + assert result == data_source_info + assert "upload_file_id" in result + assert "file_name" in result + + def test_document_data_source_info_dict_property_empty(self): + """Test data_source_info_dict property when data_source_info is None.""" + # Arrange + document = Document( + tenant_id=str(uuid4()), + dataset_id=str(uuid4()), + position=1, + data_source_type="upload_file", + batch="batch_001", + name="test.pdf", + created_from="web", + created_by=str(uuid4()), + ) + + # Act + result = document.data_source_info_dict + + # Assert + assert result == {} + + def test_document_average_segment_length(self): + """Test average_segment_length property calculation.""" + # Arrange + document = Document( + tenant_id=str(uuid4()), + dataset_id=str(uuid4()), + position=1, + data_source_type="upload_file", + batch="batch_001", + name="test.pdf", + created_from="web", + created_by=str(uuid4()), + word_count=1000, + ) + + # Mock segment_count property + with patch.object(Document, "segment_count", new_callable=lambda: property(lambda self: 10)): + # Act + result = document.average_segment_length + + # Assert + assert result == 100 + + def test_document_average_segment_length_zero(self): + """Test average_segment_length property when word_count is zero.""" + # Arrange + document = Document( + tenant_id=str(uuid4()), + dataset_id=str(uuid4()), + position=1, + data_source_type="upload_file", + batch="batch_001", + name="test.pdf", + created_from="web", + created_by=str(uuid4()), + word_count=0, + ) + + # Act + result = document.average_segment_length + + # Assert + assert result == 0 + + +class TestDocumentSegmentIndexing: + """Test suite for DocumentSegment model indexing and operations.""" + + def test_document_segment_creation_with_required_fields(self): + """Test creating a document segment with all required fields.""" + # Arrange + tenant_id = str(uuid4()) + dataset_id = str(uuid4()) + document_id = str(uuid4()) + created_by = str(uuid4()) + + # Act + segment = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + position=1, + content="This is a test segment content.", + word_count=6, + tokens=10, + created_by=created_by, + ) + + # Assert + assert segment.tenant_id == tenant_id + assert segment.dataset_id == dataset_id + assert segment.document_id == document_id + assert segment.position == 1 + assert segment.content == "This is a test segment content." 
+        assert segment.word_count == 6
+        assert segment.tokens == 10
+        assert segment.created_by == created_by
+        # Note: Default values are set by the database, not by model instantiation
+
+    def test_document_segment_with_indexing_fields(self):
+        """Test creating a document segment with indexing fields."""
+        # Arrange
+        index_node_id = str(uuid4())
+        index_node_hash = "abc123hash"
+        keywords = ["test", "segment", "indexing"]
+
+        # Act
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="Test content",
+            word_count=2,
+            tokens=5,
+            created_by=str(uuid4()),
+            index_node_id=index_node_id,
+            index_node_hash=index_node_hash,
+            keywords=keywords,
+        )
+
+        # Assert
+        assert segment.index_node_id == index_node_id
+        assert segment.index_node_hash == index_node_hash
+        assert segment.keywords == keywords
+
+    def test_document_segment_with_answer_field(self):
+        """Test creating a document segment with an answer field for the QA model."""
+        # Arrange
+        content = "What is AI?"
+        answer = "AI stands for Artificial Intelligence."
+
+        # Act
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content=content,
+            answer=answer,
+            word_count=3,
+            tokens=8,
+            created_by=str(uuid4()),
+        )
+
+        # Assert
+        assert segment.content == content
+        assert segment.answer == answer
+
+    def test_document_segment_status_transitions(self):
+        """Test document segment status field values."""
+        # Arrange & Act
+        segment_waiting = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+            status="waiting",
+        )
+        segment_completed = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+            status="completed",
+        )
+
+        # Assert
+        assert segment_waiting.status == "waiting"
+        assert segment_completed.status == "completed"
+
+    def test_document_segment_enabled_disabled_tracking(self):
+        """Test document segment enabled/disabled state tracking."""
+        # Arrange
+        disabled_by = str(uuid4())
+        disabled_at = datetime.now(UTC)
+
+        # Act
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+            enabled=False,
+            disabled_by=disabled_by,
+            disabled_at=disabled_at,
+        )
+
+        # Assert
+        assert segment.enabled is False
+        assert segment.disabled_by == disabled_by
+        assert segment.disabled_at == disabled_at
+
+    def test_document_segment_hit_count_tracking(self):
+        """Test document segment hit count tracking."""
+        # Arrange & Act
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+            hit_count=5,
+        )
+
+        # Assert
+        assert segment.hit_count == 5
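+
+    def test_document_segment_word_count_stored_verbatim(self):
+        """Supplementary sketch (not an original scenario): word_count is a
+        plain stored column, so a deliberately mismatched value is kept as
+        passed rather than recomputed from the content."""
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="one two three four",
+            word_count=999,
+            tokens=4,
+            created_by=str(uuid4()),
+        )
+
+        # Assert
+        assert segment.word_count == 999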
+
+    def test_document_segment_error_tracking(self):
+        """Test document segment error tracking."""
+        # Arrange
+        error_message = "Indexing failed due to timeout"
+        stopped_at = datetime.now(UTC)
+
+        # Act
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+            error=error_message,
+            stopped_at=stopped_at,
+        )
+
+        # Assert
+        assert segment.error == error_message
+        assert segment.stopped_at == stopped_at
+
+
+class TestEmbeddingStorage:
+    """Test suite for Embedding model storage and retrieval."""
+
+    def test_embedding_creation_with_required_fields(self):
+        """Test creating an embedding with required fields."""
+        # Arrange
+        model_name = "text-embedding-ada-002"
+        hash_value = "abc123hash"
+        provider_name = "openai"
+
+        # Act
+        embedding = Embedding(
+            model_name=model_name,
+            hash=hash_value,
+            provider_name=provider_name,
+            embedding=b"binary_data",
+        )
+
+        # Assert
+        assert embedding.model_name == model_name
+        assert embedding.hash == hash_value
+        assert embedding.provider_name == provider_name
+        assert embedding.embedding == b"binary_data"
+
+    def test_embedding_set_and_get_embedding(self):
+        """Test setting and getting embedding data."""
+        # Arrange
+        embedding_data = [0.1, 0.2, 0.3, 0.4, 0.5]
+        embedding = Embedding(
+            model_name="text-embedding-ada-002",
+            hash="test_hash",
+            provider_name="openai",
+            embedding=b"",
+        )
+
+        # Act
+        embedding.set_embedding(embedding_data)
+        retrieved_data = embedding.get_embedding()
+
+        # Assert
+        assert retrieved_data == embedding_data
+        assert len(retrieved_data) == 5
+        assert retrieved_data[0] == 0.1
+        assert retrieved_data[4] == 0.5
+
+    def test_embedding_pickle_serialization(self):
+        """Test embedding data is properly pickled."""
+        # Arrange
+        embedding_data = [0.1, 0.2, 0.3]
+        embedding = Embedding(
+            model_name="text-embedding-ada-002",
+            hash="test_hash",
+            provider_name="openai",
+            embedding=b"",
+        )
+
+        # Act
+        embedding.set_embedding(embedding_data)
+
+        # Assert
+        # Verify the embedding is stored as pickled binary data
+        assert isinstance(embedding.embedding, bytes)
+        # Verify we can unpickle it
+        unpickled_data = pickle.loads(embedding.embedding)  # noqa: S301
+        assert unpickled_data == embedding_data
+
+    def test_embedding_with_large_vector(self):
+        """Test embedding with a large-dimension vector."""
+        # Arrange
+        # Simulate a 1536-dimension vector (OpenAI ada-002 size)
+        large_embedding_data = [0.001 * i for i in range(1536)]
+        embedding = Embedding(
+            model_name="text-embedding-ada-002",
+            hash="large_vector_hash",
+            provider_name="openai",
+            embedding=b"",
+        )
+
+        # Act
+        embedding.set_embedding(large_embedding_data)
+        retrieved_data = embedding.get_embedding()
+
+        # Assert
+        assert len(retrieved_data) == 1536
+        assert retrieved_data[0] == 0.0
+        assert abs(retrieved_data[1535] - 1.535) < 0.0001  # Float comparison with tolerance
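+
+    def test_embedding_empty_vector_round_trip(self):
+        """Supplementary sketch (not an original scenario): set_embedding and
+        get_embedding are assumed to be a plain pickle dump/load pair, per the
+        pickle test above, so an empty vector should round-trip unchanged."""
+        embedding = Embedding(
+            model_name="text-embedding-ada-002",
+            hash="empty_hash",
+            provider_name="openai",
+            embedding=b"",
+        )
+
+        # Act
+        embedding.set_embedding([])
+
+        # Assert
+        assert embedding.get_embedding() == []
+        assert isinstance(embedding.embedding, bytes)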
"enabled": False}, + ], + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + } + process_rule = DatasetProcessRule( + dataset_id=str(uuid4()), + mode="custom", + created_by=str(uuid4()), + rules=json.dumps(rules_data), + ) + + # Act + result = process_rule.rules_dict + + # Assert + assert result == rules_data + assert "pre_processing_rules" in result + assert "segmentation" in result + + def test_dataset_process_rule_to_dict(self): + """Test dataset process rule to_dict method.""" + # Arrange + dataset_id = str(uuid4()) + rules_data = {"test": "data"} + process_rule = DatasetProcessRule( + dataset_id=dataset_id, + mode="automatic", + created_by=str(uuid4()), + rules=json.dumps(rules_data), + ) + + # Act + result = process_rule.to_dict() + + # Assert + assert result["dataset_id"] == dataset_id + assert result["mode"] == "automatic" + assert result["rules"] == rules_data + + def test_dataset_process_rule_automatic_rules(self): + """Test dataset process rule automatic rules constant.""" + # Act + automatic_rules = DatasetProcessRule.AUTOMATIC_RULES + + # Assert + assert "pre_processing_rules" in automatic_rules + assert "segmentation" in automatic_rules + assert automatic_rules["segmentation"]["max_tokens"] == 500 + + +class TestDatasetKeywordTable: + """Test suite for DatasetKeywordTable model.""" + + def test_dataset_keyword_table_creation(self): + """Test creating a dataset keyword table.""" + # Arrange + dataset_id = str(uuid4()) + keyword_data = {"test": ["node1", "node2"], "keyword": ["node3"]} + + # Act + keyword_table = DatasetKeywordTable( + dataset_id=dataset_id, + keyword_table=json.dumps(keyword_data), + ) + + # Assert + assert keyword_table.dataset_id == dataset_id + assert keyword_table.data_source_type == "database" # Default value + + def test_dataset_keyword_table_data_source_type(self): + """Test dataset keyword table data source type.""" + # Arrange & Act + keyword_table = DatasetKeywordTable( + dataset_id=str(uuid4()), + keyword_table="{}", + data_source_type="file", + ) + + # Assert + assert keyword_table.data_source_type == "file" + + +class TestAppDatasetJoin: + """Test suite for AppDatasetJoin model.""" + + def test_app_dataset_join_creation(self): + """Test creating an app-dataset join relationship.""" + # Arrange + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + # Act + join = AppDatasetJoin( + app_id=app_id, + dataset_id=dataset_id, + ) + + # Assert + assert join.app_id == app_id + assert join.dataset_id == dataset_id + # Note: ID is auto-generated when saved to database + + +class TestChildChunk: + """Test suite for ChildChunk model.""" + + def test_child_chunk_creation(self): + """Test creating a child chunk.""" + # Arrange + tenant_id = str(uuid4()) + dataset_id = str(uuid4()) + document_id = str(uuid4()) + segment_id = str(uuid4()) + created_by = str(uuid4()) + + # Act + child_chunk = ChildChunk( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + position=1, + content="Child chunk content", + word_count=3, + created_by=created_by, + ) + + # Assert + assert child_chunk.tenant_id == tenant_id + assert child_chunk.dataset_id == dataset_id + assert child_chunk.document_id == document_id + assert child_chunk.segment_id == segment_id + assert child_chunk.position == 1 + assert child_chunk.content == "Child chunk content" + assert child_chunk.word_count == 3 + assert child_chunk.created_by == created_by + # Note: Default values are set by database, not by model instantiation + + def 
+
+    def test_child_chunk_with_indexing_fields(self):
+        """Test creating a child chunk with indexing fields."""
+        # Arrange
+        index_node_id = str(uuid4())
+        index_node_hash = "child_hash_123"
+
+        # Act
+        child_chunk = ChildChunk(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            segment_id=str(uuid4()),
+            position=1,
+            content="Test content",
+            word_count=2,
+            created_by=str(uuid4()),
+            index_node_id=index_node_id,
+            index_node_hash=index_node_hash,
+        )
+
+        # Assert
+        assert child_chunk.index_node_id == index_node_id
+        assert child_chunk.index_node_hash == index_node_hash
+
+
+class TestDatasetDocumentCascadeDeletes:
+    """Test suite for the Dataset-Document count and aggregation properties."""
+
+    def test_dataset_with_documents_relationship(self):
+        """Test dataset can track its documents."""
+        # Arrange
+        dataset_id = str(uuid4())
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+        dataset.id = dataset_id
+
+        # Mock the database session query
+        mock_query = MagicMock()
+        mock_query.where.return_value.scalar.return_value = 3
+
+        with patch("models.dataset.db.session.query", return_value=mock_query):
+            # Act
+            total_docs = dataset.total_documents
+
+            # Assert
+            assert total_docs == 3
+
+    def test_dataset_available_documents_count(self):
+        """Test dataset can count available documents."""
+        # Arrange
+        dataset_id = str(uuid4())
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+        dataset.id = dataset_id
+
+        # Mock the database session query
+        mock_query = MagicMock()
+        mock_query.where.return_value.scalar.return_value = 2
+
+        with patch("models.dataset.db.session.query", return_value=mock_query):
+            # Act
+            available_docs = dataset.total_available_documents
+
+            # Assert
+            assert available_docs == 2
+
+    def test_dataset_word_count_aggregation(self):
+        """Test dataset can aggregate word count from documents."""
+        # Arrange
+        dataset_id = str(uuid4())
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+        dataset.id = dataset_id
+
+        # Mock the database session query
+        mock_query = MagicMock()
+        mock_query.with_entities.return_value.where.return_value.scalar.return_value = 5000
+
+        with patch("models.dataset.db.session.query", return_value=mock_query):
+            # Act
+            total_words = dataset.word_count
+
+            # Assert
+            assert total_words == 5000
+
+    def test_dataset_available_segment_count(self):
+        """Test dataset can count available segments."""
+        # Arrange
+        dataset_id = str(uuid4())
+        dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+        dataset.id = dataset_id
+
+        # Mock the database session query
+        mock_query = MagicMock()
+        mock_query.where.return_value.scalar.return_value = 15
+
+        with patch("models.dataset.db.session.query", return_value=mock_query):
+            # Act
+            segment_count = dataset.available_segment_count
+
+            # Assert
+            assert segment_count == 15
+
+    def test_document_segment_count_property(self):
+        """Test document can count its segments."""
+        # Arrange
+        document_id = str(uuid4())
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+        )
+        document.id = document_id
+
+        # Mock the database session query
+        mock_query = MagicMock()
+        mock_query.where.return_value.count.return_value = 10
+
+        with patch("models.dataset.db.session.query", return_value=mock_query):
+            # Act
+            segment_count = document.segment_count
+
+            # Assert
+            assert segment_count == 10
+
+    def test_document_hit_count_aggregation(self):
+        """Test document can aggregate hit count from segments."""
+        # Arrange
+        document_id = str(uuid4())
+        document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+        )
+        document.id = document_id
+
+        # Mock the database session query
+        mock_query = MagicMock()
+        mock_query.with_entities.return_value.where.return_value.scalar.return_value = 25
+
+        with patch("models.dataset.db.session.query", return_value=mock_query):
+            # Act
+            hit_count = document.hit_count
+
+            # Assert
+            assert hit_count == 25
+
+
+class TestDocumentSegmentNavigation:
+    """Test suite for DocumentSegment navigation properties."""
+
+    def test_document_segment_dataset_property(self):
+        """Test segment can access its parent dataset."""
+        # Arrange
+        dataset_id = str(uuid4())
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=dataset_id,
+            document_id=str(uuid4()),
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+        )
+
+        mock_dataset = Dataset(
+            tenant_id=str(uuid4()),
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=str(uuid4()),
+        )
+        mock_dataset.id = dataset_id
+
+        # Mock the database session scalar
+        with patch("models.dataset.db.session.scalar", return_value=mock_dataset):
+            # Act
+            dataset = segment.dataset
+
+            # Assert
+            assert dataset is not None
+            assert dataset.id == dataset_id
+
+    def test_document_segment_document_property(self):
+        """Test segment can access its parent document."""
+        # Arrange
+        document_id = str(uuid4())
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=document_id,
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+        )
+
+        mock_document = Document(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=str(uuid4()),
+        )
+        mock_document.id = document_id
+
+        # Mock the database session scalar
+        with patch("models.dataset.db.session.scalar", return_value=mock_document):
+            # Act
+            document = segment.document
+
+            # Assert
+            assert document is not None
+            assert document.id == document_id
+
+    def test_document_segment_previous_segment(self):
+        """Test segment can access the previous segment."""
+        # Arrange
+        document_id = str(uuid4())
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=document_id,
+            position=2,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+        )
+
+        previous_segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=document_id,
+            position=1,
+            content="Previous",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+        )
+
+        # Mock the database session scalar
+        with patch("models.dataset.db.session.scalar", return_value=previous_segment):
+            # Act
+            prev_seg = segment.previous_segment
+
+            # Assert
+            assert prev_seg is not None
+            assert prev_seg.position == 1
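+
+    def test_document_segment_previous_segment_missing(self):
+        """Supplementary sketch (not an original scenario): the navigation
+        properties surface the db.session.scalar result unchanged, as the
+        tests above rely on, so a first segment with no earlier neighbour
+        is assumed to yield None."""
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=str(uuid4()),
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+        )
+
+        with patch("models.dataset.db.session.scalar", return_value=None):
+            assert segment.previous_segment is None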
+
+    def test_document_segment_next_segment(self):
+        """Test segment can access the next segment."""
+        # Arrange
+        document_id = str(uuid4())
+        segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=document_id,
+            position=1,
+            content="Test",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+        )
+
+        next_segment = DocumentSegment(
+            tenant_id=str(uuid4()),
+            dataset_id=str(uuid4()),
+            document_id=document_id,
+            position=2,
+            content="Next",
+            word_count=1,
+            tokens=2,
+            created_by=str(uuid4()),
+        )
+
+        # Mock the database session scalar
+        with patch("models.dataset.db.session.scalar", return_value=next_segment):
+            # Act
+            next_seg = segment.next_segment
+
+            # Assert
+            assert next_seg is not None
+            assert next_seg.position == 2
+
+
+class TestModelIntegration:
+    """Test suite for model integration scenarios."""
+
+    def test_complete_dataset_document_segment_hierarchy(self):
+        """Test the complete hierarchy from dataset to segment."""
+        # Arrange
+        tenant_id = str(uuid4())
+        dataset_id = str(uuid4())
+        document_id = str(uuid4())
+        created_by = str(uuid4())
+
+        # Create dataset
+        dataset = Dataset(
+            tenant_id=tenant_id,
+            name="Test Dataset",
+            data_source_type="upload_file",
+            created_by=created_by,
+            indexing_technique="high_quality",
+        )
+        dataset.id = dataset_id
+
+        # Create document
+        document = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset_id,
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=created_by,
+            word_count=100,
+        )
+        document.id = document_id
+
+        # Create segment
+        segment = DocumentSegment(
+            tenant_id=tenant_id,
+            dataset_id=dataset_id,
+            document_id=document_id,
+            position=1,
+            content="Test segment content",
+            word_count=3,
+            tokens=5,
+            created_by=created_by,
+            status="completed",
+        )
+
+        # Assert
+        assert dataset.id == dataset_id
+        assert document.dataset_id == dataset_id
+        assert segment.dataset_id == dataset_id
+        assert segment.document_id == document_id
+        assert dataset.indexing_technique == "high_quality"
+        assert document.word_count == 100
+        assert segment.status == "completed"
+
+    def test_document_to_dict_serialization(self):
+        """Test document to_dict method for serialization."""
+        # Arrange
+        tenant_id = str(uuid4())
+        dataset_id = str(uuid4())
+        created_by = str(uuid4())
+
+        document = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset_id,
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="test.pdf",
+            created_from="web",
+            created_by=created_by,
+            word_count=100,
+            indexing_status="completed",
+        )
+
+        # Mock segment_count and hit_count
+        with (
+            patch.object(Document, "segment_count", new_callable=lambda: property(lambda self: 5)),
+            patch.object(Document, "hit_count", new_callable=lambda: property(lambda self: 10)),
+        ):
+            # Act
+            result = document.to_dict()
+
+            # Assert
+            assert result["tenant_id"] == tenant_id
+            assert result["dataset_id"] == dataset_id
+            assert result["name"] == "test.pdf"
+            assert result["word_count"] == 100
+            assert result["indexing_status"] == "completed"
+            assert result["segment_count"] == 5
+            assert result["hit_count"] == 10
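+
+
+# Usage sketch (an assumption, not part of the original change): this module is
+# a pure unit suite, with all db.session access mocked, so it can be run without
+# a database, e.g. from the api/ directory:
+#   pytest tests/unit_tests/models/test_dataset_models.py -q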