test: unit test for core.rag module (#32630)

This commit is contained in:
rajatagarwal-oss 2026-03-10 11:40:24 +05:30 committed by GitHub
parent a5832df586
commit a0ed350871
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 10255 additions and 1249 deletions

View File

@ -0,0 +1,813 @@
"""
Unit tests for DatasetDocumentStore.
Tests cover all public methods and error paths of the DatasetDocumentStore class
which provides document storage and retrieval functionality for datasets in the RAG system.
"""
from unittest.mock import MagicMock, patch
import pytest
from core.rag.docstore.dataset_docstore import DatasetDocumentStore, DocumentSegment
from core.rag.models.document import AttachmentDocument, Document
from models.dataset import Dataset
class TestDatasetDocumentStoreInit:
    """Construction of DatasetDocumentStore and its derived properties."""

    def test_init_with_all_parameters(self):
        """Dataset, user id and document id are all retained by the store."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        store = DatasetDocumentStore(
            dataset=dataset,
            user_id="test-user-id",
            document_id="test-doc-id",
        )
        assert store._dataset == dataset
        assert store._user_id == "test-user-id"
        assert store._document_id == "test-doc-id"
        assert store.dataset_id == "test-dataset-id"
        assert store.user_id == "test-user-id"

    def test_init_without_document_id(self):
        """Omitting document_id leaves it as None; dataset_id still resolves."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
        assert store._document_id is None
        assert store.dataset_id == "test-dataset-id"
class TestDatasetDocumentStoreSerialization:
    """Round-trip behaviour of to_dict and from_dict."""

    def test_to_dict(self):
        """to_dict exposes only the dataset id."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
        assert store.to_dict() == {"dataset_id": "test-dataset-id"}

    def test_from_dict(self):
        """from_dict rebuilds a store from a plain config mapping."""
        dataset = MagicMock(spec=["id"])
        dataset.id = "ds-123"
        store = DatasetDocumentStore.from_dict(
            {
                "dataset": dataset,
                "user_id": "test-user",
                "document_id": "test-doc",
            }
        )
        assert store._user_id == "test-user"
        assert store._document_id == "test-doc"
class TestDatasetDocumentStoreDocs:
    """Behaviour of the docs property."""

    def test_docs_returns_document_dict(self):
        """Segments fetched from the DB are exposed as an index-node-id -> Document map."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        segment = MagicMock(spec=DocumentSegment)
        segment.index_node_id = "node-1"
        segment.index_node_hash = "hash-1"
        segment.document_id = "doc-1"
        segment.dataset_id = "test-dataset-id"
        segment.content = "Test content"
        with patch("core.rag.docstore.dataset_docstore.db") as db_mock:
            db_mock.session.scalars.return_value.all.return_value = [segment]
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            docs = store.docs
        assert "node-1" in docs
        assert isinstance(docs["node-1"], Document)

    def test_docs_empty_dataset(self):
        """With no segments in the DB, docs is an empty mapping."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with patch("core.rag.docstore.dataset_docstore.db") as db_mock:
            db_mock.session.scalars.return_value.all.return_value = []
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.docs == {}
class TestDatasetDocumentStoreAddDocuments:
    """Behaviour of add_documents for new, updated and invalid documents."""

    def test_add_documents_new_document_with_embedding(self):
        """A new doc on a high-quality dataset counts tokens and inserts a segment."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "high_quality"
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "model"
        doc = MagicMock(spec=Document)
        doc.page_content = "Test content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "hash-1"}
        doc.attachments = None
        doc.children = None
        model_instance = MagicMock()
        model_instance.get_text_embedding_num_tokens.return_value = [10]
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch("core.rag.docstore.dataset_docstore.ModelManager") as manager_cls,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            db_mock.session.query.return_value.where.return_value.scalar.return_value = None
            manager_cls.return_value.get_model_instance.return_value = model_instance
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_documents([doc])
            db_mock.session.add.assert_called()
            db_mock.session.commit.assert_called()

    def test_add_documents_update_existing_document(self):
        """An already-stored doc id is updated in place (allow_update defaults to True)."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"
        dataset.embedding_model_provider = None
        dataset.embedding_model = None
        doc = MagicMock(spec=Document)
        doc.page_content = "Updated content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "new-hash"}
        doc.attachments = None
        doc.children = None
        existing = MagicMock()
        existing.id = "seg-1"
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=existing),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            db_mock.session.query.return_value.where.return_value.scalar.return_value = 5
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_documents([doc])
            db_mock.session.commit.assert_called()

    def test_add_documents_raises_when_not_allowed(self):
        """allow_update=False on an existing doc id raises ValueError."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"
        doc = MagicMock(spec=Document)
        doc.page_content = "Test content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "hash-1"}
        doc.attachments = None
        doc.children = None
        existing = MagicMock()
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=existing),
        ):
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            with pytest.raises(ValueError, match="already exists"):
                store.add_documents([doc], allow_update=False)

    def test_add_documents_with_answer_metadata(self):
        """A doc carrying an 'answer' metadata key is still inserted."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"
        doc = MagicMock(spec=Document)
        doc.page_content = "Test content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "hash-1", "answer": "Test answer"}
        doc.attachments = None
        doc.children = None
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            db_mock.session.query.return_value.where.return_value.scalar.return_value = None
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_documents([doc])
            db_mock.session.add.assert_called()

    def test_add_documents_with_invalid_document_type(self):
        """Anything that is not a Document is rejected with ValueError."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with patch("core.rag.docstore.dataset_docstore.db"):
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            with pytest.raises(ValueError, match="must be a Document"):
                store.add_documents(["not a document"])

    def test_add_documents_with_none_metadata(self):
        """A Document whose metadata is None is rejected with ValueError."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        doc = MagicMock(spec=Document)
        doc.page_content = "Test content"
        doc.metadata = None
        with patch("core.rag.docstore.dataset_docstore.db"):
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            with pytest.raises(ValueError, match="metadata must be a dict"):
                store.add_documents([doc])

    def test_add_documents_with_save_child(self):
        """save_child=True persists the parent together with its child chunks."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"
        child = MagicMock(spec=Document)
        child.page_content = "Child content"
        child.metadata = {"doc_id": "child-1", "doc_hash": "child-hash"}
        doc = MagicMock(spec=Document)
        doc.page_content = "Test content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "hash-1"}
        doc.attachments = None
        doc.children = [child]
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            db_mock.session.query.return_value.where.return_value.scalar.return_value = None
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_documents([doc], save_child=True)
            db_mock.session.add.assert_called()
class TestDatasetDocumentStoreExists:
    """Behaviour of document_exists."""

    def test_document_exists_returns_true(self):
        """A found segment means the document exists."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        segment = MagicMock()
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.document_exists("doc-1") is True

    def test_document_exists_returns_false(self):
        """A missing segment means the document does not exist."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.document_exists("doc-1") is False
class TestDatasetDocumentStoreGetDocument:
    """Behaviour of get_document."""

    def test_get_document_success(self):
        """A stored segment is converted back into a Document."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        segment = MagicMock(spec=DocumentSegment)
        segment.index_node_id = "node-1"
        segment.index_node_hash = "hash-1"
        segment.document_id = "doc-1"
        segment.dataset_id = "test-dataset-id"
        segment.content = "Test content"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            doc = store.get_document("node-1", raise_error=False)
        assert isinstance(doc, Document)
        assert doc.page_content == "Test content"

    def test_get_document_returns_none_when_not_found(self):
        """Missing segment with raise_error=False yields None."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.get_document("nonexistent", raise_error=False) is None

    def test_get_document_raises_when_not_found(self):
        """Missing segment with raise_error=True raises ValueError."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            with pytest.raises(ValueError, match="not found"):
                store.get_document("nonexistent", raise_error=True)
class TestDatasetDocumentStoreDeleteDocument:
    """Behaviour of delete_document."""

    def test_delete_document_success(self):
        """An existing segment is deleted and the session committed."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        segment = MagicMock()
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            store.delete_document("doc-1")
            db_mock.session.delete.assert_called_with(segment)
            db_mock.session.commit.assert_called()

    def test_delete_document_returns_none_when_not_found(self):
        """Missing segment with raise_error=False returns None silently."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.delete_document("nonexistent", raise_error=False) is None

    def test_delete_document_raises_when_not_found(self):
        """Missing segment with raise_error=True raises ValueError."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            with pytest.raises(ValueError, match="not found"):
                store.delete_document("nonexistent", raise_error=True)
class TestDatasetDocumentStoreHashOperations:
    """Behaviour of set_document_hash and get_document_hash."""

    def test_set_document_hash_success(self):
        """The segment hash is overwritten and the session committed."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        segment = MagicMock()
        segment.index_node_hash = "old-hash"
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            store.set_document_hash("doc-1", "new-hash")
            assert segment.index_node_hash == "new-hash"
            db_mock.session.commit.assert_called()

    def test_set_document_hash_returns_none_when_not_found(self):
        """Setting a hash on a missing segment is a silent no-op returning None."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.set_document_hash("nonexistent", "new-hash") is None

    def test_get_document_hash_success(self):
        """The stored index_node_hash is returned for an existing segment."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        segment = MagicMock()
        segment.index_node_hash = "test-hash"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.get_document_hash("doc-1") == "test-hash"

    def test_get_document_hash_returns_none_when_not_found(self):
        """A missing segment yields None instead of a hash."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.get_document_hash("nonexistent") is None
class TestDatasetDocumentStoreSegment:
    """Behaviour of get_document_segment."""

    def test_get_document_segment_returns_segment(self):
        """The segment returned by the DB scalar query is passed through."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        segment = MagicMock(spec=DocumentSegment)
        with patch("core.rag.docstore.dataset_docstore.db") as db_mock:
            db_mock.session.scalar.return_value = segment
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.get_document_segment("doc-1") == segment

    def test_get_document_segment_returns_none(self):
        """An empty scalar result yields None."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        with patch("core.rag.docstore.dataset_docstore.db") as db_mock:
            db_mock.session.scalar.return_value = None
            store = DatasetDocumentStore(dataset=dataset, user_id="test-user-id")
            assert store.get_document_segment("nonexistent") is None
class TestDatasetDocumentStoreMultimodelBinding:
    """Behaviour of add_multimodel_documents_binding."""

    def test_add_multimodel_documents_binding_with_attachments(self):
        """Attachments produce binding rows added to the session."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        attachment = MagicMock(spec=AttachmentDocument)
        attachment.metadata = {"doc_id": "attachment-1"}
        with patch("core.rag.docstore.dataset_docstore.db") as db_mock:
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_multimodel_documents_binding("seg-1", [attachment])
            db_mock.session.add.assert_called()

    def test_add_multimodel_documents_binding_without_attachments(self):
        """A None attachment list writes nothing to the session."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        with patch("core.rag.docstore.dataset_docstore.db") as db_mock:
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_multimodel_documents_binding("seg-1", None)
            db_mock.session.add.assert_not_called()

    def test_add_multimodel_documents_binding_with_empty_list(self):
        """An empty attachment list writes nothing to the session."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        with patch("core.rag.docstore.dataset_docstore.db") as db_mock:
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_multimodel_documents_binding("seg-1", [])
            db_mock.session.add.assert_not_called()
class TestDatasetDocumentStoreAddDocumentsUpdateChild:
    """add_documents updating an existing segment while saving child chunks."""

    def test_add_documents_update_existing_with_children(self):
        """save_child=True on an existing segment deletes old children and commits."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"
        child = MagicMock(spec=Document)
        child.page_content = "Updated child content"
        child.metadata = {"doc_id": "child-1", "doc_hash": "new-child-hash"}
        doc = MagicMock(spec=Document)
        doc.page_content = "Updated content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "new-hash"}
        doc.attachments = None
        doc.children = [child]
        existing = MagicMock()
        existing.id = "seg-1"
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=existing),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            db_mock.session.query.return_value.where.return_value.scalar.return_value = 5
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_documents([doc], save_child=True)
            db_mock.session.query.return_value.where.return_value.delete.assert_called()
            db_mock.session.commit.assert_called()
class TestDatasetDocumentStoreAddDocumentsUpdateAnswer:
    """add_documents updating an existing segment that carries answer metadata."""

    def test_add_documents_update_existing_with_answer(self):
        """An existing segment updated with an 'answer' metadata key still commits."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"
        doc = MagicMock(spec=Document)
        doc.page_content = "Updated content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "new-hash", "answer": "Updated answer"}
        doc.attachments = None
        doc.children = None
        existing = MagicMock()
        existing.id = "seg-1"
        with (
            patch("core.rag.docstore.dataset_docstore.db") as db_mock,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=existing),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            db_mock.session.query.return_value.where.return_value.scalar.return_value = 5
            store = DatasetDocumentStore(
                dataset=dataset, user_id="test-user-id", document_id="test-doc-id"
            )
            store.add_documents([doc])
            db_mock.session.commit.assert_called()

View File

@ -0,0 +1,555 @@
"""Unit tests for cached_embedding.py - CacheEmbedding class.
This test file covers the methods not fully tested in test_embedding_service.py:
- embed_multimodal_documents
- embed_multimodal_query
- Error handling scenarios in embed_query (DEBUG mode)
"""
import base64
from decimal import Decimal
from unittest.mock import Mock, patch
import numpy as np
import pytest
from sqlalchemy.exc import IntegrityError
from core.rag.embedding.cached_embedding import CacheEmbedding
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
from models.dataset import Embedding
class TestCacheEmbeddingMultimodalDocuments:
"""Test suite for CacheEmbedding.embed_multimodal_documents method."""
@pytest.fixture
def mock_model_instance(self):
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "vision-embedding-model"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
model_type_instance = Mock()
model_instance.model_type_instance = model_type_instance
model_schema = Mock()
model_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
model_type_instance.get_model_schema.return_value = model_schema
return model_instance
@pytest.fixture
def sample_multimodal_result(self):
"""Create a sample multimodal EmbeddingResult."""
embedding_vector = np.random.randn(1536)
normalized_vector = (embedding_vector / np.linalg.norm(embedding_vector)).tolist()
usage = EmbeddingUsage(
tokens=10,
total_tokens=10,
unit_price=Decimal("0.0001"),
price_unit=Decimal(1000),
total_price=Decimal("0.000001"),
currency="USD",
latency=0.5,
)
return EmbeddingResult(
model="vision-embedding-model",
embeddings=[normalized_vector],
usage=usage,
)
def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result):
"""Test embedding a single multimodal document when cache is empty."""
cache_embedding = CacheEmbedding(mock_model_instance, user="test-user")
documents = [{"file_id": "file123", "content": "test content"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
result = cache_embedding.embed_multimodal_documents(documents)
assert len(result) == 1
assert isinstance(result[0], list)
assert len(result[0]) == 1536
mock_model_instance.invoke_multimodal_embedding.assert_called_once()
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
def test_embed_multiple_multimodal_documents_cache_miss(self, mock_model_instance):
"""Test embedding multiple multimodal documents when cache is empty."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [
{"file_id": "file1", "content": "content 1"},
{"file_id": "file2", "content": "content 2"},
{"file_id": "file3", "content": "content 3"},
]
embeddings = []
for _ in range(3):
vector = np.random.randn(1536)
normalized = (vector / np.linalg.norm(vector)).tolist()
embeddings.append(normalized)
usage = EmbeddingUsage(
tokens=30,
total_tokens=30,
unit_price=Decimal("0.0001"),
price_unit=Decimal(1000),
total_price=Decimal("0.000003"),
currency="USD",
latency=0.8,
)
embedding_result = EmbeddingResult(
model="vision-embedding-model",
embeddings=embeddings,
usage=usage,
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents)
assert len(result) == 3
assert all(len(emb) == 1536 for emb in result)
def test_embed_multimodal_documents_cache_hit(self, mock_model_instance):
"""Test embedding multimodal documents when embeddings are cached."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [{"file_id": "file123"}]
cached_vector = np.random.randn(1536)
normalized_cached = (cached_vector / np.linalg.norm(cached_vector)).tolist()
mock_cached_embedding = Mock(spec=Embedding)
mock_cached_embedding.get_embedding.return_value = normalized_cached
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
result = cache_embedding.embed_multimodal_documents(documents)
assert len(result) == 1
assert result[0] == normalized_cached
mock_model_instance.invoke_multimodal_embedding.assert_not_called()
def test_embed_multimodal_documents_partial_cache_hit(self, mock_model_instance):
"""Test embedding multimodal documents with mixed cache hits and misses."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [
{"file_id": "cached_file"},
{"file_id": "new_file_1"},
{"file_id": "new_file_2"},
]
cached_vector = np.random.randn(1536)
normalized_cached = (cached_vector / np.linalg.norm(cached_vector)).tolist()
mock_cached_embedding = Mock(spec=Embedding)
mock_cached_embedding.get_embedding.return_value = normalized_cached
new_embeddings = []
for _ in range(2):
vector = np.random.randn(1536)
normalized = (vector / np.linalg.norm(vector)).tolist()
new_embeddings.append(normalized)
usage = EmbeddingUsage(
tokens=20,
total_tokens=20,
unit_price=Decimal("0.0001"),
price_unit=Decimal(1000),
total_price=Decimal("0.000002"),
currency="USD",
latency=0.6,
)
embedding_result = EmbeddingResult(
model="vision-embedding-model",
embeddings=new_embeddings,
usage=usage,
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
call_count = [0]
def mock_filter_by(**kwargs):
call_count[0] += 1
mock_query = Mock()
if call_count[0] == 1:
mock_query.first.return_value = mock_cached_embedding
else:
mock_query.first.return_value = None
return mock_query
mock_session.query.return_value.filter_by = mock_filter_by
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents)
assert len(result) == 3
assert result[0] == normalized_cached
def test_embed_multimodal_documents_nan_handling(self, mock_model_instance):
"""Test handling of NaN values in multimodal embeddings."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [{"file_id": "valid"}, {"file_id": "nan"}]
valid_vector = np.random.randn(1536).tolist()
nan_vector = [float("nan")] * 1536
usage = EmbeddingUsage(
tokens=20,
total_tokens=20,
unit_price=Decimal("0.0001"),
price_unit=Decimal(1000),
total_price=Decimal("0.000002"),
currency="USD",
latency=0.5,
)
embedding_result = EmbeddingResult(
model="vision-embedding-model",
embeddings=[valid_vector, nan_vector],
usage=usage,
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
result = cache_embedding.embed_multimodal_documents(documents)
assert len(result) == 2
assert result[0] is not None
assert result[1] is None
mock_logger.warning.assert_called_once()
def test_embed_multimodal_documents_large_batch(self, mock_model_instance):
"""Test embedding large batch of multimodal documents respecting MAX_CHUNKS."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [{"file_id": f"file{i}"} for i in range(25)]
def create_batch_result(batch_size):
embeddings = []
for _ in range(batch_size):
vector = np.random.randn(1536)
normalized = (vector / np.linalg.norm(vector)).tolist()
embeddings.append(normalized)
usage = EmbeddingUsage(
tokens=batch_size * 10,
total_tokens=batch_size * 10,
unit_price=Decimal("0.0001"),
price_unit=Decimal(1000),
total_price=Decimal(str(batch_size * 0.000001)),
currency="USD",
latency=0.5,
)
return EmbeddingResult(
model="vision-embedding-model",
embeddings=embeddings,
usage=usage,
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)]
mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results
result = cache_embedding.embed_multimodal_documents(documents)
assert len(result) == 25
assert mock_model_instance.invoke_multimodal_embedding.call_count == 3
def test_embed_multimodal_documents_api_error(self, mock_model_instance):
"""Test handling of API errors during multimodal embedding."""
cache_embedding = CacheEmbedding(mock_model_instance)
documents = [{"file_id": "file123"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")
with pytest.raises(Exception) as exc_info:
cache_embedding.embed_multimodal_documents(documents)
assert "API Error" in str(exc_info.value)
mock_session.rollback.assert_called()
def test_embed_multimodal_documents_integrity_error_during_transform(
    self, mock_model_instance, sample_multimodal_result
):
    """Test handling of IntegrityError during embedding transformation."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [{"file_id": "file123"}]
    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        mock_session.query.return_value.filter_by.return_value.first.return_value = None
        mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
        # Simulate a duplicate-key violation when persisting the cache entry;
        # this must be swallowed (rollback) without losing the embedding result.
        mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None)
        embeddings = embedder.embed_multimodal_documents(docs)
        assert len(embeddings) == 1
        mock_session.rollback.assert_called()
class TestCacheEmbeddingMultimodalQuery:
    """Test suite for CacheEmbedding.embed_multimodal_query method."""

    @pytest.fixture
    def mock_model_instance(self):
        """Create a mock ModelInstance for testing."""
        model_instance = Mock()
        model_instance.model = "vision-embedding-model"
        model_instance.provider = "openai"
        model_instance.credentials = {"api_key": "test-key"}
        return model_instance

    def test_embed_multimodal_query_cache_miss(self, mock_model_instance):
        """Test embedding multimodal query when Redis cache is empty."""
        cache_embedding = CacheEmbedding(mock_model_instance, user="test-user")
        document = {"file_id": "file123"}
        vector = np.random.randn(1536)
        normalized = (vector / np.linalg.norm(vector)).tolist()
        usage = EmbeddingUsage(
            tokens=5,
            total_tokens=5,
            unit_price=Decimal("0.0001"),
            price_unit=Decimal(1000),
            total_price=Decimal("0.0000005"),
            currency="USD",
            latency=0.3,
        )
        embedding_result = EmbeddingResult(
            model="vision-embedding-model",
            embeddings=[normalized],
            usage=usage,
        )
        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            # None from redis.get signals a cache miss.
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
            result = cache_embedding.embed_multimodal_query(document)
            assert isinstance(result, list)
            assert len(result) == 1536
            # On a miss the freshly computed embedding is written back to Redis.
            mock_redis.setex.assert_called_once()

    def test_embed_multimodal_query_cache_hit(self, mock_model_instance):
        """Test embedding multimodal query when Redis cache has the value."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        document = {"file_id": "file123"}
        # Cached values are stored base64-encoded raw float bytes.
        embedding_vector = np.random.randn(1536)
        vector_bytes = embedding_vector.tobytes()
        encoded_vector = base64.b64encode(vector_bytes).decode("utf-8")
        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = encoded_vector.encode()
            result = cache_embedding.embed_multimodal_query(document)
            assert isinstance(result, list)
            assert len(result) == 1536
            # A hit refreshes the key TTL and must not invoke the model at all.
            mock_redis.expire.assert_called_once()
            mock_model_instance.invoke_multimodal_embedding.assert_not_called()

    def test_embed_multimodal_query_nan_handling(self, mock_model_instance):
        """Test handling of NaN values in multimodal query embeddings."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        # An all-NaN vector cannot be normalized; the embedder must reject it.
        nan_vector = [float("nan")] * 1536
        usage = EmbeddingUsage(
            tokens=5,
            total_tokens=5,
            unit_price=Decimal("0.0001"),
            price_unit=Decimal(1000),
            total_price=Decimal("0.0000005"),
            currency="USD",
            latency=0.3,
        )
        embedding_result = EmbeddingResult(
            model="vision-embedding-model",
            embeddings=[nan_vector],
            usage=usage,
        )
        document = {"file_id": "file123"}
        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
            with pytest.raises(ValueError) as exc_info:
                cache_embedding.embed_multimodal_query(document)
            assert "Normalized embedding is nan" in str(exc_info.value)

    def test_embed_multimodal_query_api_error(self, mock_model_instance):
        """Test handling of API errors during multimodal query embedding."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        document = {"file_id": "file123"}
        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")
            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                # With DEBUG off, the original exception still propagates.
                mock_config.DEBUG = False
                with pytest.raises(Exception) as exc_info:
                    cache_embedding.embed_multimodal_query(document)
                assert "API Error" in str(exc_info.value)

    def test_embed_multimodal_query_redis_set_error(self, mock_model_instance):
        """Test handling of Redis set errors during multimodal query embedding."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        document = {"file_id": "file123"}
        vector = np.random.randn(1536)
        normalized = (vector / np.linalg.norm(vector)).tolist()
        usage = EmbeddingUsage(
            tokens=5,
            total_tokens=5,
            unit_price=Decimal("0.0001"),
            price_unit=Decimal(1000),
            total_price=Decimal("0.0000005"),
            currency="USD",
            latency=0.3,
        )
        embedding_result = EmbeddingResult(
            model="vision-embedding-model",
            embeddings=[normalized],
            usage=usage,
        )
        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
            # The embedding succeeds but the cache write fails.
            mock_redis.setex.side_effect = RuntimeError("Redis Error")
            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                # In DEBUG mode, cache-write failures are re-raised.
                mock_config.DEBUG = True
                with pytest.raises(RuntimeError):
                    cache_embedding.embed_multimodal_query(document)
class TestCacheEmbeddingQueryErrors:
    """Test suite for error handling in CacheEmbedding.embed_query method."""

    @pytest.fixture
    def mock_model_instance(self):
        """Create a mock ModelInstance for testing."""
        instance = Mock()
        instance.model = "text-embedding-ada-002"
        instance.provider = "openai"
        instance.credentials = {"api_key": "test-key"}
        return instance

    def test_embed_query_api_error_debug_mode(self, mock_model_instance):
        """Test handling of API errors in debug mode."""
        embedder = CacheEmbedding(mock_model_instance)
        query = "test query"
        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_text_embedding.side_effect = RuntimeError("API Error")
            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                mock_config.DEBUG = True
                with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
                    # In DEBUG mode the failure is logged and then re-raised.
                    with pytest.raises(RuntimeError) as exc_info:
                        embedder.embed_query(query)
                    assert "API Error" in str(exc_info.value)
                    mock_logger.exception.assert_called()

    def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance):
        """Test handling of Redis set errors in debug mode."""
        embedder = CacheEmbedding(mock_model_instance)
        query = "test query"
        raw = np.random.randn(1536)
        usage = EmbeddingUsage(
            tokens=5,
            total_tokens=5,
            unit_price=Decimal("0.0001"),
            price_unit=Decimal(1000),
            total_price=Decimal("0.0000005"),
            currency="USD",
            latency=0.3,
        )
        model_output = EmbeddingResult(
            model="text-embedding-ada-002",
            embeddings=[(raw / np.linalg.norm(raw)).tolist()],
            usage=usage,
        )
        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_text_embedding.return_value = model_output
            # Embedding succeeds; only the Redis cache write fails.
            mock_redis.setex.side_effect = RuntimeError("Redis Error")
            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                mock_config.DEBUG = True
                with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
                    with pytest.raises(RuntimeError):
                        embedder.embed_query(query)
                    mock_logger.exception.assert_called()
class TestCacheEmbeddingInitialization:
    """Test suite for CacheEmbedding initialization."""

    def test_initialization_with_user(self):
        """Test CacheEmbedding initialization with user parameter."""
        instance = Mock()
        instance.model = "test-model"
        instance.provider = "test-provider"
        embedder = CacheEmbedding(instance, user="test-user")
        # Both the model instance and the user are stored verbatim.
        assert embedder._model_instance == instance
        assert embedder._user == "test-user"

    def test_initialization_without_user(self):
        """Test CacheEmbedding initialization without user parameter."""
        instance = Mock()
        instance.model = "test-model"
        instance.provider = "test-provider"
        embedder = CacheEmbedding(instance)
        assert embedder._model_instance == instance
        # Omitting the user argument leaves it as None.
        assert embedder._user is None

View File

@ -0,0 +1,220 @@
"""Unit tests for embedding_base.py - the abstract Embeddings base class."""
import asyncio
import inspect
from typing import Any
import pytest
from core.rag.embedding.embedding_base import Embeddings
class ConcreteEmbeddings(Embeddings):
    """Concrete implementation of Embeddings for testing."""

    @staticmethod
    def _unit_row() -> list[float]:
        # All methods return fixed 10-element vectors of ones.
        return [1.0] * 10

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self._unit_row() for _ in texts]

    def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
        return [self._unit_row() for _ in multimodel_documents]

    def embed_query(self, text: str) -> list[float]:
        return self._unit_row()

    def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
        return self._unit_row()
class TestEmbeddingsBase:
    """Test suite for the abstract Embeddings base class.

    The abstractness checks and the raise-NotImplementedError source checks
    were previously duplicated inline in every test; they are factored into
    two private helpers so each contract is asserted in exactly one place.
    """

    @staticmethod
    def _assert_abstract(method_name: str) -> None:
        """Assert that *method_name* is registered as abstract on Embeddings."""
        assert method_name in Embeddings.__abstractmethods__

    @staticmethod
    def _assert_body_raises_not_implemented(method) -> None:
        """Assert that the default body of *method* raises NotImplementedError."""
        assert "raise NotImplementedError" in inspect.getsource(method)

    def test_embeddings_is_abc(self):
        """Test that Embeddings is an abstract base class."""
        assert hasattr(Embeddings, "__abstractmethods__")
        assert len(Embeddings.__abstractmethods__) > 0

    def test_embed_documents_is_abstract(self):
        """Test that embed_documents is an abstract method."""
        self._assert_abstract("embed_documents")

    def test_embed_multimodal_documents_is_abstract(self):
        """Test that embed_multimodal_documents is an abstract method."""
        self._assert_abstract("embed_multimodal_documents")

    def test_embed_query_is_abstract(self):
        """Test that embed_query is an abstract method."""
        self._assert_abstract("embed_query")

    def test_embed_multimodal_query_is_abstract(self):
        """Test that embed_multimodal_query is an abstract method."""
        self._assert_abstract("embed_multimodal_query")

    def test_embed_documents_raises_not_implemented(self):
        """Test that embed_documents raises NotImplementedError in its body."""
        self._assert_body_raises_not_implemented(Embeddings.embed_documents)

    def test_embed_multimodal_documents_raises_not_implemented(self):
        """Test that embed_multimodal_documents raises NotImplementedError in its body."""
        self._assert_body_raises_not_implemented(Embeddings.embed_multimodal_documents)

    def test_embed_query_raises_not_implemented(self):
        """Test that embed_query raises NotImplementedError in its body."""
        self._assert_body_raises_not_implemented(Embeddings.embed_query)

    def test_embed_multimodal_query_raises_not_implemented(self):
        """Test that embed_multimodal_query raises NotImplementedError in its body."""
        self._assert_body_raises_not_implemented(Embeddings.embed_multimodal_query)

    def test_aembed_documents_raises_not_implemented(self):
        """Test that aembed_documents raises NotImplementedError in its body."""
        self._assert_body_raises_not_implemented(Embeddings.aembed_documents)

    def test_aembed_query_raises_not_implemented(self):
        """Test that aembed_query raises NotImplementedError in its body."""
        self._assert_body_raises_not_implemented(Embeddings.aembed_query)

    def test_concrete_implementation_works(self):
        """Test that a concrete implementation of Embeddings works correctly."""
        concrete = ConcreteEmbeddings()
        result = concrete.embed_documents(["test1", "test2"])
        assert len(result) == 2
        assert all(len(emb) == 10 for emb in result)

    def test_concrete_implementation_embed_query(self):
        """Test concrete implementation of embed_query."""
        concrete = ConcreteEmbeddings()
        result = concrete.embed_query("test query")
        assert len(result) == 10

    def test_concrete_implementation_embed_multimodal_documents(self):
        """Test concrete implementation of embed_multimodal_documents."""
        concrete = ConcreteEmbeddings()
        docs: list[dict[str, Any]] = [{"file_id": "file1"}, {"file_id": "file2"}]
        result = concrete.embed_multimodal_documents(docs)
        assert len(result) == 2

    def test_concrete_implementation_embed_multimodal_query(self):
        """Test concrete implementation of embed_multimodal_query."""
        concrete = ConcreteEmbeddings()
        result = concrete.embed_multimodal_query({"file_id": "test"})
        assert len(result) == 10
class TestEmbeddingsNotImplemented:
    """Test that abstract methods raise NotImplementedError when called."""

    @staticmethod
    def _make_partial():
        """Build an object whose methods delegate to the unimplemented Embeddings bodies.

        The six monkey-patched delegates were previously copy-pasted into every
        test method; constructing them here keeps the tests in sync and makes
        each test show only the call under scrutiny.
        """

        class PartialImpl:
            pass

        PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text)
        PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts)
        PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs)
        PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc)
        PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts)
        PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text)
        return PartialImpl()

    def test_embed_query_raises_not_implemented(self):
        """Test that embed_query raises NotImplementedError."""
        partial = self._make_partial()
        with pytest.raises(NotImplementedError):
            partial.embed_query("test")

    def test_embed_documents_raises_not_implemented(self):
        """Test that embed_documents raises NotImplementedError."""
        partial = self._make_partial()
        with pytest.raises(NotImplementedError):
            partial.embed_documents(["test"])

    def test_embed_multimodal_documents_raises_not_implemented(self):
        """Test that embed_multimodal_documents raises NotImplementedError."""
        partial = self._make_partial()
        with pytest.raises(NotImplementedError):
            partial.embed_multimodal_documents([{"file_id": "test"}])

    def test_embed_multimodal_query_raises_not_implemented(self):
        """Test that embed_multimodal_query raises NotImplementedError."""
        partial = self._make_partial()
        with pytest.raises(NotImplementedError):
            partial.embed_multimodal_query({"file_id": "test"})

    def test_aembed_documents_raises_not_implemented(self):
        """Test that aembed_documents raises NotImplementedError."""
        partial = self._make_partial()

        async def run_test():
            with pytest.raises(NotImplementedError):
                await partial.aembed_documents(["test"])

        asyncio.run(run_test())

    def test_aembed_query_raises_not_implemented(self):
        """Test that aembed_query raises NotImplementedError."""
        partial = self._make_partial()

        async def run_test():
            with pytest.raises(NotImplementedError):
                await partial.aembed_query("test")

        asyncio.run(run_test())

View File

@ -0,0 +1,85 @@
from io import BytesIO
import pytest
from core.rag.extractor.blob.blob import Blob
class TestBlob:
    def test_requires_data_or_path(self):
        # Constructing a Blob with neither data nor path is invalid.
        with pytest.raises(ValueError, match="Either data or path must be provided"):
            Blob()

    def test_source_property_and_repr_include_path(self, tmp_path):
        sample = tmp_path / "sample.txt"
        sample.write_text("hello", encoding="utf-8")
        blob = Blob.from_path(str(sample))
        assert blob.source == str(sample)
        assert str(sample) in repr(blob)

    def test_as_string_from_bytes_and_str(self):
        assert Blob.from_data(b"abc").as_string() == "abc"
        assert Blob.from_data("plain-text").as_string() == "plain-text"

    def test_as_string_from_path(self, tmp_path):
        sample = tmp_path / "sample.txt"
        sample.write_text("from-file", encoding="utf-8")
        assert Blob.from_path(str(sample)).as_string() == "from-file"

    def test_as_string_raises_for_invalid_state(self):
        # model_construct bypasses validation, so both data and path can be None.
        empty_blob = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8")
        with pytest.raises(ValueError, match="Unable to get string for blob"):
            empty_blob.as_string()

    def test_as_bytes_from_bytes_str_and_path(self, tmp_path):
        binary_file = tmp_path / "sample.bin"
        binary_file.write_bytes(b"from-path")
        assert Blob.from_data(b"abc").as_bytes() == b"abc"
        assert Blob.from_data("abc", encoding="utf-8").as_bytes() == b"abc"
        assert Blob.from_path(str(binary_file)).as_bytes() == b"from-path"

    def test_as_bytes_raises_for_invalid_state(self):
        empty_blob = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8")
        with pytest.raises(ValueError, match="Unable to get bytes for blob"):
            empty_blob.as_bytes()

    def test_as_bytes_io_for_bytes_and_path(self, tmp_path):
        with Blob.from_data(b"bytes-io").as_bytes_io() as stream:
            assert isinstance(stream, BytesIO)
            assert stream.read() == b"bytes-io"
        stream_file = tmp_path / "stream.bin"
        stream_file.write_bytes(b"path-stream")
        with Blob.from_path(str(stream_file)).as_bytes_io() as stream:
            assert stream.read() == b"path-stream"

    def test_as_bytes_io_raises_for_unsupported_data_type(self):
        # A str payload cannot be exposed as a byte stream.
        text_blob = Blob.from_data("text-value")
        with pytest.raises(NotImplementedError, match="Unable to convert blob"):
            with text_blob.as_bytes_io():
                pass

    def test_from_path_respects_guessing_and_explicit_mime(self, tmp_path):
        text_file = tmp_path / "example.txt"
        text_file.write_text("x", encoding="utf-8")
        # Default: guess from the extension; explicit mime_type wins when guessing is off.
        assert Blob.from_path(str(text_file)).mimetype == "text/plain"
        assert Blob.from_path(str(text_file), mime_type="custom/type", guess_type=False).mimetype == "custom/type"

View File

@ -1,61 +1,337 @@
import os
"""Unit tests for Firecrawl app and extractor integration points."""
import json
from collections.abc import Mapping
from typing import Any
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
import core.rag.extractor.firecrawl.firecrawl_app as firecrawl_module
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
    """crawl_url returns the job id string from a (mocked) Firecrawl API."""
    url = "https://firecrawl.dev"
    # Fall back to a dummy key so the test stays deterministic offline.
    api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
    app = FirecrawlApp(api_key=api_key, base_url="https://api.firecrawl.dev")
    crawl_params = {
        "includePaths": [],
        "excludePaths": [],
        "maxDepth": 1,
        "limit": 1,
    }
    mocker.patch("httpx.post", return_value=_mock_response({"id": "test"}))
    job_id = app.crawl_url(url, crawl_params)
    assert job_id is not None
    assert isinstance(job_id, str)
def _response(status_code: int, json_data: Mapping[str, Any] | None = None, text: str = "") -> MagicMock:
response = MagicMock()
response.status_code = status_code
response.text = text
response.json.return_value = json_data if json_data is not None else {}
return response
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
    """The crawl endpoint URL is identical with or without a trailing slash on base_url."""
    api_key = "fc-"
    for base_url in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
        app = FirecrawlApp(api_key=api_key, base_url=base_url)
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"id": "job123"}
        mock_post = mocker.patch("httpx.post", return_value=response)
        app.crawl_url("https://example.com", params=None)
        # First positional argument to httpx.post is the request URL.
        assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v2/crawl"
class TestFirecrawlApp:
def test_init_requires_api_key_for_default_base_url(self):
with pytest.raises(ValueError, match="No API key provided"):
FirecrawlApp(api_key=None, base_url="https://api.firecrawl.dev")
def test_prepare_headers_and_build_url(self):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev/")
assert app._prepare_headers() == {
"Content-Type": "application/json",
"Authorization": "Bearer fc-key",
}
assert app._build_url("/v2/crawl") == "https://custom.firecrawl.dev/v2/crawl"
def test_scrape_url_success(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch(
"httpx.post",
return_value=_response(
200,
{
"data": {
"metadata": {
"title": "t",
"description": "d",
"sourceURL": "https://example.com",
},
"markdown": "body",
}
},
),
)
result = app.scrape_url("https://example.com", params={"onlyMainContent": False})
assert result == {
"title": "t",
"description": "d",
"source_url": "https://example.com",
"markdown": "body",
}
def test_scrape_url_handles_known_error_status(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("boom"))
mocker.patch("httpx.post", return_value=_response(429, {"error": "limit"}))
with pytest.raises(Exception, match="boom"):
app.scrape_url("https://example.com")
mock_handle.assert_called_once()
def test_scrape_url_unknown_status_raises(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch("httpx.post", return_value=_response(404, text="Not Found"))
with pytest.raises(Exception, match="Failed to scrape URL. Status code: 404"):
app.scrape_url("https://example.com")
def test_crawl_url_success(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch("httpx.post", return_value=_response(200, {"id": "job-1"}))
assert app.crawl_url("https://example.com") == "job-1"
def test_crawl_url_non_200_uses_error_handler(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl failed"))
mocker.patch("httpx.post", return_value=_response(500, {"error": "server"}))
with pytest.raises(Exception, match="crawl failed"):
app.crawl_url("https://example.com")
mock_handle.assert_called_once()
def test_map_success(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch("httpx.post", return_value=_response(200, {"success": True, "links": ["a", "b"]}))
assert app.map("https://example.com") == {"success": True, "links": ["a", "b"]}
def test_map_known_error(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"}))
assert app.map("https://example.com") == {}
mock_handle.assert_called_once()
def test_map_unknown_error_raises(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch("httpx.post", return_value=_response(418, text="teapot"))
with pytest.raises(Exception, match="Failed to start map job. Status code: 418"):
app.map("https://example.com")
def test_check_crawl_status_completed_with_data(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
payload = {
"status": "completed",
"total": 2,
"completed": 2,
"data": [
{
"metadata": {"title": "a", "description": "desc-a", "sourceURL": "https://a"},
"markdown": "m-a",
},
{
"metadata": {"title": "b", "description": "desc-b", "sourceURL": "https://b"},
"markdown": "m-b",
},
{"metadata": {"title": "skip"}},
],
}
mocker.patch("httpx.get", return_value=_response(200, payload))
save_calls: list[tuple[str, bytes]] = []
delete_calls: list[str] = []
mock_storage = MagicMock()
mock_storage.exists.return_value = True
mock_storage.delete.side_effect = lambda key: delete_calls.append(key)
mock_storage.save.side_effect = lambda key, data: save_calls.append((key, data))
mocker.patch.object(firecrawl_module, "storage", mock_storage)
result = app.check_crawl_status("job-42")
assert result["status"] == "completed"
assert result["total"] == 2
assert result["current"] == 2
assert len(result["data"]) == 2
assert delete_calls == ["website_files/job-42.txt"]
assert len(save_calls) == 1
assert save_calls[0][0] == "website_files/job-42.txt"
def test_check_crawl_status_completed_with_zero_total_raises(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": 0, "data": []}))
with pytest.raises(Exception, match="No page found"):
app.check_crawl_status("job-1")
def test_check_crawl_status_non_completed(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
payload = {"status": "processing", "total": 5, "completed": 1, "data": []}
mocker.patch("httpx.get", return_value=_response(200, payload))
assert app.check_crawl_status("job-1") == {
"status": "processing",
"total": 5,
"current": 1,
"data": [],
}
def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mocker.patch("httpx.get", return_value=_response(500, {"error": "server"}))
assert app.check_crawl_status("job-1") == {}
mock_handle.assert_called_once()
def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
payload = {
"status": "completed",
"total": 1,
"completed": 1,
"data": [{"metadata": {"title": "a", "sourceURL": "https://a"}, "markdown": "m-a"}],
}
mocker.patch("httpx.get", return_value=_response(200, payload))
mock_storage = MagicMock()
mock_storage.exists.return_value = False
mock_storage.save.side_effect = RuntimeError("save failed")
mocker.patch.object(firecrawl_module, "storage", mock_storage)
with pytest.raises(Exception, match="Error saving crawl data"):
app.check_crawl_status("job-err")
def test_extract_common_fields_and_status_formatter(self):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
fields = app._extract_common_fields(
{"metadata": {"title": "t", "description": "d", "sourceURL": "u"}, "markdown": "m"}
)
assert fields == {"title": "t", "description": "d", "source_url": "u", "markdown": "m"}
status = app._format_crawl_status_response("completed", {"total": 1, "completed": 1}, [fields])
assert status == {"status": "completed", "total": 1, "current": 1, "data": [fields]}
def test_post_and_get_request_retry_logic(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep")
resp_502_a = _response(502)
resp_502_b = _response(502)
resp_200 = _response(200)
mocker.patch("httpx.post", side_effect=[resp_502_a, resp_200])
post_result = app._post_request("u", {"x": 1}, {"h": 1}, retries=3, backoff_factor=0.5)
assert post_result is resp_200
mocker.patch("httpx.get", side_effect=[resp_502_b, _response(200)])
get_result = app._get_request("u", {"h": 1}, retries=3, backoff_factor=0.25)
assert get_result.status_code == 200
assert sleep_mock.call_count == 2
def test_post_and_get_request_return_last_502(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep")
last_post = _response(502)
mocker.patch("httpx.post", side_effect=[_response(502), last_post])
assert app._post_request("u", {}, {}, retries=2).status_code == 502
last_get = _response(502)
mocker.patch("httpx.get", side_effect=[_response(502), last_get])
assert app._get_request("u", {}, retries=2).status_code == 502
assert sleep_mock.call_count == 4
def test_handle_error_with_json_and_plain_text(self):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
json_error = _response(400, {"message": "bad request"})
with pytest.raises(Exception, match="bad request"):
app._handle_error(json_error, "run task")
non_json = MagicMock()
non_json.status_code = 400
non_json.text = "plain error"
non_json.json.side_effect = json.JSONDecodeError("bad", "x", 0)
with pytest.raises(Exception, match="plain error"):
app._handle_error(non_json, "run task")
def test_search_success(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch("httpx.post", return_value=_response(200, {"success": True, "data": [{"url": "x"}]}))
assert app.search("python")["success"] is True
def test_search_warning_failure(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mocker.patch("httpx.post", return_value=_response(200, {"success": False, "warning": "bad search"}))
with pytest.raises(Exception, match="bad search"):
app.search("python")
def test_search_known_http_error(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"}))
assert app.search("python") == {}
mock_handle.assert_called_once()
def test_search_unknown_http_error(self, mocker: MockerFixture):
    """Status codes outside the handled set raise a generic search failure."""
    teapot_response = _response(418, text="teapot")
    mocker.patch("httpx.post", return_value=teapot_response)
    app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
    with pytest.raises(Exception, match="Failed to perform search. Status code: 418"):
        app.search("python")
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
    # NOTE(review): this function appears truncated diff residue — it arranges a 404
    # response whose .json() raises, but its act/assert lines are missing from this
    # span (they look to have been interleaved into the class below by the diff).
    # Verify against the repository version before relying on it.
    api_key = "fc-"
    app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/")
    mock_post = mocker.patch("httpx.post")
    mock_resp = MagicMock()
    mock_resp.status_code = 404
    mock_resp.text = "Not Found"
    mock_resp.json.side_effect = Exception("Not JSON")
    mock_post.return_value = mock_resp
class TestFirecrawlWebExtractor:
    """Tests for FirecrawlWebExtractor routing through WebsiteService."""

    def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture):
        # Stub the crawl lookup so extract() does not hit the network.
        mocker.patch(
            "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data",
            return_value={
                "markdown": "crawl content",
                "source_url": "https://example.com",
                "description": "desc",
                "title": "title",
            },
        )
        # NOTE(review): the next two lines reference `app`, which is not defined in
        # this test — they look like diff residue from the removed module-level
        # error-handler test above. Confirm against the repository version.
        with pytest.raises(Exception) as excinfo:
            app.scrape_url("https://example.com")
        extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
        docs = extractor.extract()
        # Should not raise a JSONDecodeError; current behavior reports status code only
        # NOTE(review): the `excinfo` assertion below also appears to belong to the
        # removed module-level test.
        assert str(excinfo.value) == "Failed to scrape URL. Status code: 404"
        assert len(docs) == 1
        assert docs[0].page_content == "crawl content"
        assert docs[0].metadata["source_url"] == "https://example.com"

    def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture):
        # When the crawl lookup yields nothing, extract() returns no documents.
        mocker.patch(
            "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data",
            return_value=None,
        )
        extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
        assert extractor.extract() == []

    def test_extract_scrape_mode_returns_document(self, mocker: MockerFixture):
        # Scrape mode calls get_scrape_url_data with provider, URL, tenant, and the
        # only_main_content flag, wrapping the markdown into one Document.
        mock_scrape = mocker.patch(
            "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_scrape_url_data",
            return_value={
                "markdown": "scrape content",
                "source_url": "https://example.com",
                "description": "desc",
                "title": "title",
            },
        )
        extractor = FirecrawlWebExtractor(
            "https://example.com", "job-1", "tenant-1", mode="scrape", only_main_content=False
        )
        docs = extractor.extract()
        assert len(docs) == 1
        assert docs[0].page_content == "scrape content"
        mock_scrape.assert_called_once_with("firecrawl", "https://example.com", "tenant-1", False)

    def test_extract_unknown_mode_returns_empty(self):
        # Unrecognized modes fall through to an empty result rather than raising.
        extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="unknown")
        assert extractor.extract() == []

View File

@ -0,0 +1,95 @@
import csv
import io
from types import SimpleNamespace
import pandas as pd
import pytest
import core.rag.extractor.csv_extractor as csv_module
from core.rag.extractor.csv_extractor import CSVExtractor
class _ManagedStringIO(io.StringIO):
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
return False
class TestCSVExtractor:
    """Tests for CSVExtractor: row formatting, source-column handling, and encoding fallback."""

    def test_extract_success_with_source_column(self, tmp_path):
        # Each row becomes one document; metadata["source"] comes from the named column.
        file_path = tmp_path / "data.csv"
        file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8")
        extractor = CSVExtractor(str(file_path), source_column="id")
        docs = extractor.extract()
        assert len(docs) == 1
        assert docs[0].page_content == "id: source-1;body: hello"
        assert docs[0].metadata == {"source": "source-1", "row": 0}

    def test_extract_raises_when_source_column_missing(self, tmp_path):
        file_path = tmp_path / "data.csv"
        file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8")
        extractor = CSVExtractor(str(file_path), source_column="missing_col")
        with pytest.raises(ValueError, match="Source column 'missing_col' not found"):
            extractor.extract()

    def test_extract_wraps_unicode_error_when_autodetect_disabled(self, monkeypatch):
        # With autodetect off, a UnicodeDecodeError is wrapped in RuntimeError.
        extractor = CSVExtractor("dummy.csv", autodetect_encoding=False)

        def raise_decode(*args, **kwargs):
            raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")

        monkeypatch.setattr("builtins.open", raise_decode)
        with pytest.raises(RuntimeError, match="Error loading dummy.csv"):
            extractor.extract()

    def test_extract_autodetect_encoding_success(self, monkeypatch):
        # The extractor first tries the default open (encoding=None), then each
        # detected candidate in order until one succeeds — attempts recorded below.
        extractor = CSVExtractor("dummy.csv", autodetect_encoding=True)
        attempted_encodings: list[str | None] = []

        def fake_open(path, newline="", encoding=None):
            attempted_encodings.append(encoding)
            if encoding is None:
                raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")
            if encoding == "bad":
                raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")
            return _ManagedStringIO("id,body\nsource-1,hello\n")

        monkeypatch.setattr("builtins.open", fake_open)
        monkeypatch.setattr(
            csv_module,
            "detect_file_encodings",
            lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")],
        )
        docs = extractor.extract()
        assert len(docs) == 1
        assert docs[0].page_content == "id: source-1;body: hello"
        assert attempted_encodings == [None, "bad", "utf-8"]

    def test_extract_autodetect_encoding_all_attempts_fail_returns_empty(self, monkeypatch):
        # When every candidate encoding fails, extract() returns [] rather than raising.
        extractor = CSVExtractor("dummy.csv", autodetect_encoding=True)

        def always_raise(*args, **kwargs):
            raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")

        monkeypatch.setattr("builtins.open", always_raise)
        monkeypatch.setattr(csv_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="bad")])
        assert extractor.extract() == []

    def test_read_from_file_re_raises_csv_error(self, monkeypatch):
        # A csv.Error raised by pandas must propagate unchanged.
        extractor = CSVExtractor("dummy.csv")
        monkeypatch.setattr(pd, "read_csv", lambda *args, **kwargs: (_ for _ in ()).throw(csv.Error("bad csv")))
        with pytest.raises(csv.Error, match="bad csv"):
            extractor._read_from_file(io.StringIO("x"))

View File

@ -0,0 +1,117 @@
from types import SimpleNamespace
import pandas as pd
import pytest
import core.rag.extractor.excel_extractor as excel_module
from core.rag.extractor.excel_extractor import ExcelExtractor
class _FakeCell:
def __init__(self, value, hyperlink=None):
self.value = value
self.hyperlink = hyperlink
class _FakeSheet:
def __init__(self, header_rows, data_rows):
self._header_rows = header_rows
self._data_rows = data_rows
def iter_rows(self, min_row=1, max_row=None, max_col=None, values_only=False):
if values_only:
for row in self._header_rows:
yield tuple(row)
return
for row in self._data_rows:
if max_col is not None:
yield tuple(row[:max_col])
else:
yield tuple(row)
class _FakeWorkbook:
def __init__(self, sheets):
self._sheets = sheets
self.sheetnames = list(sheets.keys())
self.closed = False
def __getitem__(self, key):
return self._sheets[key]
def close(self):
self.closed = True
class TestExcelExtractor:
    """Tests for ExcelExtractor: xlsx/xls parsing, header detection, and error paths."""

    def test_extract_xlsx_with_hyperlinks_and_sheet_skip(self, monkeypatch):
        # One sheet with data (including a hyperlink cell and an all-empty row that
        # gets skipped) plus one empty sheet that contributes no documents.
        sheet_with_data = _FakeSheet(
            header_rows=[("Name", "Link")],
            data_rows=[
                (_FakeCell("Alice"), _FakeCell("Doc", hyperlink=SimpleNamespace(target="https://example.com/doc"))),
                (_FakeCell(None), _FakeCell(123)),
                (_FakeCell(None), _FakeCell(None)),
            ],
        )
        empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[])
        workbook = _FakeWorkbook({"Data": sheet_with_data, "Empty": empty_sheet})
        monkeypatch.setattr(excel_module, "load_workbook", lambda *args, **kwargs: workbook)
        extractor = ExcelExtractor("/tmp/sample.xlsx")
        docs = extractor.extract()
        # The workbook must be closed even on the happy path.
        assert workbook.closed is True
        assert len(docs) == 2
        # Hyperlinked cells are rendered as markdown links; None values as "".
        assert docs[0].page_content == '"Name":"Alice";"Link":"[Doc](https://example.com/doc)"'
        assert docs[1].page_content == '"Name":"";"Link":"123"'
        assert all(doc.metadata["source"] == "/tmp/sample.xlsx" for doc in docs)

    def test_extract_xls_path(self, monkeypatch):
        # .xls files go through pandas; the all-NaN row is dropped.
        class FakeExcelFile:
            sheet_names = ["Sheet1"]

            def parse(self, sheet_name):
                assert sheet_name == "Sheet1"
                return pd.DataFrame([{"A": "x", "B": 1}, {"A": None, "B": None}])

        monkeypatch.setattr(pd, "ExcelFile", lambda path, engine=None: FakeExcelFile())
        extractor = ExcelExtractor("/tmp/sample.xls")
        docs = extractor.extract()
        assert len(docs) == 1
        assert docs[0].page_content == '"A":"x";"B":"1.0"'
        assert docs[0].metadata == {"source": "/tmp/sample.xls"}

    def test_extract_unsupported_extension_raises(self):
        extractor = ExcelExtractor("/tmp/sample.txt")
        with pytest.raises(ValueError, match="Unsupported file extension"):
            extractor.extract()

    def test_find_header_and_columns_prefers_first_row_with_two_columns(self):
        # The header scan settles on the first row with two non-empty cells
        # (row index is 1-based here).
        sheet = _FakeSheet(
            header_rows=[(None, None, None), ("A", "B", None), ("X", None, None)],
            data_rows=[],
        )
        extractor = ExcelExtractor("dummy.xlsx")
        header_row_idx, column_map, max_col_idx = extractor._find_header_and_columns(sheet)
        assert header_row_idx == 2
        assert column_map == {0: "A", 1: "B"}
        assert max_col_idx == 2

    def test_find_header_and_columns_fallback_and_empty_case(self):
        # With no two-cell row the first row holding any value wins; a fully
        # empty sheet yields the (0, {}, 0) sentinel.
        extractor = ExcelExtractor("dummy.xlsx")
        fallback_sheet = _FakeSheet(header_rows=[("Only", None), (None, "Second")], data_rows=[])
        row_idx, column_map, max_col_idx = extractor._find_header_and_columns(fallback_sheet)
        assert row_idx == 1
        assert column_map == {0: "Only"}
        assert max_col_idx == 1
        empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[])
        assert extractor._find_header_and_columns(empty_sheet) == (0, {}, 0)

View File

@ -0,0 +1,272 @@
from pathlib import Path
from types import SimpleNamespace
import pytest
import core.rag.extractor.extract_processor as processor_module
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
class _ExtractorFactory:
    """Factory for dummy extractor classes that record how they were constructed.

    Each dummy appends (name, args, kwargs) to ``calls`` on construction and
    returns a single canned Document from extract(), so routing tests can
    assert which extractor class was selected and with what arguments.
    """

    def __init__(self) -> None:
        self.calls = []

    def make(self, name: str) -> type[object]:
        recorded_calls = self.calls  # close over the shared log, not over self

        class DummyExtractor:
            def __init__(self, *args, **kwargs):
                recorded_calls.append((name, args, kwargs))

            def extract(self):
                return [Document(page_content=f"extracted-by-{name}")]

        return DummyExtractor
def _patch_all_extractors(monkeypatch) -> _ExtractorFactory:
    """Swap every extractor class referenced by extract_processor for recording dummies.

    Returns the factory whose ``calls`` list records (class_name, args, kwargs)
    for each construction, so routing tests can assert which extractor ran.
    """
    factory = _ExtractorFactory()
    for cls_name in [
        "CSVExtractor",
        "ExcelExtractor",
        "FirecrawlWebExtractor",
        "HtmlExtractor",
        "JinaReaderWebExtractor",
        "MarkdownExtractor",
        "NotionExtractor",
        "PdfExtractor",
        "TextExtractor",
        "UnstructuredEmailExtractor",
        "UnstructuredEpubExtractor",
        "UnstructuredMarkdownExtractor",
        "UnstructuredMsgExtractor",
        "UnstructuredPPTExtractor",
        "UnstructuredPPTXExtractor",
        "UnstructuredWordExtractor",
        "UnstructuredXmlExtractor",
        "WaterCrawlWebExtractor",
        "WordExtractor",
    ]:
        monkeypatch.setattr(processor_module, cls_name, factory.make(cls_name))
    return factory
class TestExtractProcessorLoaders:
    """Tests for the load_from_upload_file / load_from_url convenience wrappers."""

    def test_load_from_upload_file_return_docs_and_text(self, monkeypatch):
        # Stub ExtractSetting and ExtractProcessor.extract so only wrapper logic runs.
        monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs))
        monkeypatch.setattr(
            ExtractProcessor,
            "extract",
            lambda extract_setting, is_automatic=False, file_path=None: [
                Document(page_content="doc-1"),
                Document(page_content="doc-2"),
            ],
        )
        upload_file = SimpleNamespace(key="file.txt")
        docs = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=False)
        text = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=True)
        assert len(docs) == 2
        # return_text=True joins page_content values with newlines.
        assert text == "doc-1\ndoc-2"

    @pytest.mark.parametrize(
        ("url", "headers", "expected_suffix"),
        [
            ("https://example.com/file.txt", {"Content-Type": "text/plain"}, ".txt"),
            ("https://example.com/no_suffix", {"Content-Type": "application/pdf"}, ".pdf"),
            (
                "https://example.com/no_suffix",
                {"Content-Disposition": 'attachment; filename="report.md"'},
                ".md",
            ),
            (
                "https://example.com/no_suffix",
                {"Content-Disposition": 'attachment; filename="report"'},
                "",
            ),
        ],
    )
    def test_load_from_url_builds_temp_file_with_correct_suffix(self, monkeypatch, url, headers, expected_suffix):
        # The temp-file suffix is derived from the URL path or the response headers
        # (Content-Type / Content-Disposition) — see the parametrize cases above.
        response = SimpleNamespace(headers=headers, content=b"body")
        monkeypatch.setattr(processor_module.ssrf_proxy, "get", lambda *args, **kwargs: response)
        monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs))
        captured = {}

        def fake_extract(extract_setting, is_automatic=False, file_path=None):
            # First invocation captures the docs-path, the second the text-path file.
            key = "file_path_docs" if "file_path_docs" not in captured else "file_path_text"
            captured[key] = file_path
            return [Document(page_content="u1"), Document(page_content="u2")]

        monkeypatch.setattr(ExtractProcessor, "extract", fake_extract)
        docs = ExtractProcessor.load_from_url(url, return_text=False)
        assert captured["file_path_docs"].endswith(expected_suffix)
        text = ExtractProcessor.load_from_url(url, return_text=True)
        assert captured["file_path_text"].endswith(expected_suffix)
        assert len(docs) == 2
        assert text == "u1\nu2"
class TestExtractProcessorFileRouting:
    """Verifies file-extension → extractor-class routing for both ETL modes."""

    @pytest.fixture(autouse=True)
    def _set_unstructured_config(self, monkeypatch):
        # "Unstructured" routing reads the API URL/key from config; pin both.
        monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_URL", "https://unstructured")
        monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_KEY", "key")

    def _run_extract_for_extension(self, monkeypatch, extension: str, etl_type: str, is_automatic: bool = False):
        # Shared driver: patch extractors/storage/tempfile, run extract(), and
        # return (name, args, kwargs) of the extractor actually constructed.
        factory = _patch_all_extractors(monkeypatch)
        monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", etl_type)

        def fake_download(key: str, local_path: str):
            Path(local_path).write_text("content", encoding="utf-8")

        monkeypatch.setattr(processor_module.storage, "download", fake_download)
        # Pin tempfile's name generator so the downloaded path is deterministic.
        monkeypatch.setattr(processor_module.tempfile, "_get_candidate_names", lambda: iter(["candidate-name"]))
        setting = SimpleNamespace(
            datasource_type=DatasourceType.FILE,
            upload_file=SimpleNamespace(key=f"uploaded{extension}", tenant_id="tenant-1", created_by="user-1"),
        )
        docs = ExtractProcessor.extract(setting, is_automatic=is_automatic)
        assert len(docs) == 1
        assert docs[0].page_content.startswith("extracted-by-")
        return factory.calls[-1][0], factory.calls[-1][1], factory.calls[-1][2]

    @pytest.mark.parametrize(
        ("extension", "expected_extractor", "is_automatic"),
        [
            (".xlsx", "ExcelExtractor", False),
            (".xls", "ExcelExtractor", False),
            (".pdf", "PdfExtractor", False),
            (".md", "UnstructuredMarkdownExtractor", True),
            (".mdx", "MarkdownExtractor", False),
            (".htm", "HtmlExtractor", False),
            (".html", "HtmlExtractor", False),
            (".docx", "WordExtractor", False),
            (".doc", "UnstructuredWordExtractor", False),
            (".csv", "CSVExtractor", False),
            (".msg", "UnstructuredMsgExtractor", False),
            (".eml", "UnstructuredEmailExtractor", False),
            (".ppt", "UnstructuredPPTExtractor", False),
            (".pptx", "UnstructuredPPTXExtractor", False),
            (".xml", "UnstructuredXmlExtractor", False),
            (".epub", "UnstructuredEpubExtractor", False),
            (".txt", "TextExtractor", False),
        ],
    )
    def test_extract_routes_file_extensions_for_unstructured_mode(
        self, monkeypatch, extension, expected_extractor, is_automatic
    ):
        extractor_name, args, kwargs = self._run_extract_for_extension(
            monkeypatch, extension, etl_type="Unstructured", is_automatic=is_automatic
        )
        assert extractor_name == expected_extractor
        # The chosen extractor must receive at least the file path positionally.
        assert args

    @pytest.mark.parametrize(
        ("extension", "expected_extractor"),
        [
            (".xlsx", "ExcelExtractor"),
            (".pdf", "PdfExtractor"),
            (".markdown", "MarkdownExtractor"),
            (".html", "HtmlExtractor"),
            (".docx", "WordExtractor"),
            (".csv", "CSVExtractor"),
            (".epub", "UnstructuredEpubExtractor"),
            (".txt", "TextExtractor"),
        ],
    )
    def test_extract_routes_file_extensions_for_default_mode(self, monkeypatch, extension, expected_extractor):
        extractor_name, _, _ = self._run_extract_for_extension(monkeypatch, extension, etl_type="SelfHosted")
        assert extractor_name == expected_extractor

    def test_extract_requires_upload_file_when_file_path_not_provided(self):
        # Without a file path, a missing upload_file trips the guard assertion.
        setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
        with pytest.raises(AssertionError, match="upload_file is required"):
            ExtractProcessor.extract(setting)
class TestExtractProcessorDatasourceRouting:
    """Routing for NOTION and WEBSITE datasource types plus validation errors."""

    def test_extract_routes_notion_datasource(self, monkeypatch):
        factory = _patch_all_extractors(monkeypatch)
        notion_info = SimpleNamespace(
            notion_workspace_id="ws",
            notion_obj_id="obj",
            notion_page_type="page",
            document="doc",
            tenant_id="tenant",
            credential_id="cred",
        )
        setting = SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=notion_info)
        docs = ExtractProcessor.extract(setting)
        assert docs[0].page_content == "extracted-by-NotionExtractor"
        assert factory.calls[-1][0] == "NotionExtractor"

    @pytest.mark.parametrize(
        ("provider", "expected"),
        [
            ("firecrawl", "FirecrawlWebExtractor"),
            ("watercrawl", "WaterCrawlWebExtractor"),
            ("jinareader", "JinaReaderWebExtractor"),
        ],
    )
    def test_extract_routes_website_datasource_providers(self, monkeypatch, provider: str, expected: str):
        # Each website provider maps to its dedicated web extractor class.
        factory = _patch_all_extractors(monkeypatch)
        website_info = SimpleNamespace(
            provider=provider,
            url="https://example.com",
            job_id="job",
            tenant_id="tenant",
            mode="crawl",
            only_main_content=True,
        )
        setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=website_info)
        docs = ExtractProcessor.extract(setting)
        assert docs[0].page_content == f"extracted-by-{expected}"
        assert factory.calls[-1][0] == expected

    def test_extract_unsupported_website_provider(self):
        bad_provider = SimpleNamespace(
            provider="unknown",
            url="https://example.com",
            job_id="job",
            tenant_id="tenant",
            mode="crawl",
            only_main_content=True,
        )
        setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=bad_provider)
        with pytest.raises(ValueError, match="Unsupported website provider"):
            ExtractProcessor.extract(setting)

    def test_extract_unsupported_datasource_type(self):
        with pytest.raises(ValueError, match="Unsupported datasource type"):
            ExtractProcessor.extract(SimpleNamespace(datasource_type="unknown"))

    def test_extract_requires_notion_info(self):
        with pytest.raises(AssertionError, match="notion_info is required"):
            ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=None))

    def test_extract_requires_website_info(self):
        with pytest.raises(AssertionError, match="website_info is required"):
            ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=None))

View File

@ -0,0 +1,26 @@
import pytest
from core.rag.extractor.extractor_base import BaseExtractor
class _CallsBaseExtractor(BaseExtractor):
    """Subclass that forwards to the abstract base, exposing its NotImplementedError path."""

    def extract(self):
        # Deliberately delegate to the base implementation instead of overriding it.
        return super().extract()
class _ConcreteExtractor(BaseExtractor):
    """Minimal concrete subclass with a canned extract() result."""

    def extract(self):
        return ["ok"]
class TestBaseExtractor:
    """Contract tests for the abstract BaseExtractor.extract hook."""

    def test_extract_default_raises_not_implemented(self):
        """Calling the base implementation must raise NotImplementedError."""
        with pytest.raises(NotImplementedError):
            _CallsBaseExtractor().extract()

    def test_concrete_extractor_can_override(self):
        """A subclass override is invoked normally and its result returned."""
        result = _ConcreteExtractor().extract()
        assert result == ["ok"]

View File

@ -1,10 +1,55 @@
import tempfile
from types import SimpleNamespace
from core.rag.extractor.helpers import FileEncoding, detect_file_encodings
import pytest
from core.rag.extractor import helpers
from core.rag.extractor.helpers import detect_file_encodings
def test_detect_file_encodings() -> None:
    # NOTE(review): old-version test left by the diff — superseded by
    # TestHelpers.test_detect_file_encodings below, which adds temp.flush()
    # (without it the buffered write may not be on disk when detection reads
    # the file) and loosens the exact "utf_8" expectation. Remove once the
    # new class is confirmed in place.
    with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp:
        temp.write("Shared data")
        temp_path = temp.name
        assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")]
class TestHelpers:
    """Tests for detect_file_encodings: happy path, timeout, and detection failure."""

    def test_detect_file_encodings(self) -> None:
        with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp:
            temp.write("Shared data")
            temp.flush()  # ensure bytes hit the disk before detection reads the file
            temp_path = temp.name
            encodings = detect_file_encodings(temp_path)
            assert len(encodings) == 1
            # Charset detection may report either name for plain ASCII text.
            assert encodings[0].encoding in {"utf_8", "ascii"}
            assert encodings[0].confidence == 0.0
            # Assert the language field for full coverage
            assert encodings[0].language is not None

    def test_detect_file_encodings_timeout(self, monkeypatch):
        # Replace the executor so the detection future always times out.
        class FakeFuture:
            def result(self, timeout=None):
                raise helpers.concurrent.futures.TimeoutError()

        class FakeExecutor:
            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

            def submit(self, fn, file_path):
                return FakeFuture()

        monkeypatch.setattr(helpers.concurrent.futures, "ThreadPoolExecutor", lambda: FakeExecutor())
        with pytest.raises(TimeoutError, match="Timeout reached while detecting encoding"):
            detect_file_encodings("file.txt", timeout=1)

    def test_detect_file_encodings_raises_when_encoding_not_detected(self, monkeypatch):
        # A charset_normalizer best() result with encoding=None must raise.
        class FakeResult:
            encoding = None
            coherence = 0.0
            language = None

        monkeypatch.setattr(
            helpers.charset_normalizer, "from_path", lambda _: SimpleNamespace(best=lambda: FakeResult())
        )
        with pytest.raises(RuntimeError, match="Could not detect encoding"):
            detect_file_encodings("file.txt")

View File

@ -0,0 +1,21 @@
from core.rag.extractor.html_extractor import HtmlExtractor
class TestHtmlExtractor:
    """Tests for HtmlExtractor text extraction."""

    def test_extract_returns_text_content(self, tmp_path):
        """Markup is stripped, leaving only the text nodes."""
        sample = tmp_path / "sample.html"
        sample.write_text("<html><body><h1>Title</h1><p>Hello</p></body></html>", encoding="utf-8")
        docs = HtmlExtractor(str(sample)).extract()
        assert len(docs) == 1
        collapsed = "".join(docs[0].page_content.split())
        assert collapsed == "TitleHello"

    def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path):
        """A body containing only whitespace extracts to the empty string."""
        sample = tmp_path / "sample.html"
        sample.write_text("<html><body> \n </body></html>", encoding="utf-8")
        assert HtmlExtractor(str(sample))._load_as_text() == ""

View File

@ -0,0 +1,47 @@
from pytest_mock import MockerFixture
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
class TestJinaReaderWebExtractor:
    """Tests for JinaReaderWebExtractor crawl-mode extraction."""

    def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture):
        # Crawl data is wrapped into a Document with url/description/title metadata.
        mocker.patch(
            "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
            return_value={
                "content": "markdown-content",
                "url": "https://example.com",
                "description": "desc",
                "title": "title",
            },
        )
        extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
        docs = extractor.extract()
        assert len(docs) == 1
        assert docs[0].page_content == "markdown-content"
        assert docs[0].metadata == {
            "source_url": "https://example.com",
            "description": "desc",
            "title": "title",
        }

    def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture):
        # Missing crawl data yields no documents.
        mocker.patch(
            "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
            return_value=None,
        )
        extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
        assert extractor.extract() == []

    def test_extract_non_crawl_mode_returns_empty(self, mocker: MockerFixture):
        # Non-crawl modes short-circuit without touching WebsiteService at all.
        mock_get_crawl = mocker.patch(
            "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
            return_value={"content": "unused"},
        )
        extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape")
        assert extractor.extract() == []
        mock_get_crawl.assert_not_called()

View File

@ -1,8 +1,15 @@
from pathlib import Path
from types import SimpleNamespace
import pytest
import core.rag.extractor.markdown_extractor as markdown_module
from core.rag.extractor.markdown_extractor import MarkdownExtractor
# NOTE(review): this span is diff residue — an older module-level
# test_markdown_to_tups() is interleaved with the new
# TestMarkdownExtractor.test_markdown_to_tups method, and a stray diff hunk
# header ("@ -11,12 +18,113 @@ ...") sits inside the markdown fixture text.
# The two assertion bodies below are the old and new variants of the same
# test (the new one fixes `key == None` to `key is None`). Reconcile
# against the repository version before editing.
def test_markdown_to_tups():
markdown = """
class TestMarkdownExtractor:
def test_markdown_to_tups(self):
markdown = """
this is some text without header
# title 1
@ -11,12 +18,113 @@ this is balabala text
## title 2
this is more specific text.
"""
extractor = MarkdownExtractor(file_path="dummy_path")
updated_output = extractor.markdown_to_tups(markdown)
assert len(updated_output) == 3
key, header_value = updated_output[0]
assert key == None
assert header_value.strip() == "this is some text without header"
title_1, value = updated_output[1]
assert title_1.strip() == "title 1"
assert value.strip() == "this is balabala text"
# NOTE(review): below is the new-version body of the same assertions.
extractor = MarkdownExtractor(file_path="dummy_path")
updated_output = extractor.markdown_to_tups(markdown)
assert len(updated_output) == 3
key, header_value = updated_output[0]
assert key is None
assert header_value.strip() == "this is some text without header"
title_1, value = updated_output[1]
assert title_1.strip() == "title 1"
assert value.strip() == "this is balabala text"
def test_markdown_to_tups_keeps_code_block_headers_literal(self):
    """'#' lines inside a fenced code block must not be parsed as headings."""
    markdown = """# Header
before
```python
# this is not a heading
print('x')
```
after
"""
    extractor = MarkdownExtractor(file_path="dummy_path")
    tups = extractor.markdown_to_tups(markdown)
    assert len(tups) == 2
    assert tups[1][0] == "Header"
    assert "# this is not a heading" in tups[1][1]
def test_remove_images_and_hyperlinks(self):
    """remove_images drops image embeds; remove_hyperlinks keeps only the link text."""
    extractor = MarkdownExtractor(file_path="dummy_path")
    assert extractor.remove_images("before ![[image.png]] after") == "before after"
    assert extractor.remove_hyperlinks("[OpenAI](https://openai.com)") == "OpenAI"
def test_parse_tups_reads_file_and_applies_options(self, tmp_path):
    # remove_hyperlinks/remove_images should scrub links and image embeds
    # before the header split.
    markdown_file = tmp_path / "doc.md"
    markdown_file.write_text("# Header\nText with [link](https://example.com) and ![[img.png]]", encoding="utf-8")
    extractor = MarkdownExtractor(
        file_path=str(markdown_file),
        remove_hyperlinks=True,
        remove_images=True,
        autodetect_encoding=False,
    )
    tups = extractor.parse_tups(str(markdown_file))
    assert len(tups) == 2
    assert tups[1][0] == "Header"
    assert "[link]" not in tups[1][1]
    assert "img.png" not in tups[1][1]
def test_parse_tups_autodetects_encoding_after_decode_error(self, monkeypatch):
    # The default read (encoding=None) fails, the first detected candidate
    # fails, and the second succeeds — each attempt is recorded in `calls`.
    extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=True)
    calls: list[str | None] = []

    def fake_read_text(self, encoding=None):
        calls.append(encoding)
        if encoding is None:
            raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail")
        if encoding == "bad-encoding":
            raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail")
        return "# H\ncontent"

    monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True)
    monkeypatch.setattr(
        markdown_module,
        "detect_file_encodings",
        lambda _: [SimpleNamespace(encoding="bad-encoding"), SimpleNamespace(encoding="utf-8")],
    )
    tups = extractor.parse_tups("dummy_path")
    assert len(tups) == 2
    assert calls == [None, "bad-encoding", "utf-8"]
def test_parse_tups_decode_error_with_autodetect_disabled_raises(self, monkeypatch):
    """With autodetect off, a decode failure surfaces as RuntimeError."""
    def _always_decode_error(self, encoding=None):
        raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail")

    monkeypatch.setattr(Path, "read_text", _always_decode_error, raising=True)
    extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=False)
    with pytest.raises(RuntimeError, match="Error loading dummy_path"):
        extractor.parse_tups("dummy_path")
def test_parse_tups_other_exceptions_are_wrapped(self, monkeypatch):
    """Non-decode I/O errors are also wrapped in RuntimeError."""
    def _disk_failure(self, encoding=None):
        raise OSError("disk error")

    monkeypatch.setattr(Path, "read_text", _disk_failure, raising=True)
    extractor = MarkdownExtractor(file_path="dummy_path")
    with pytest.raises(RuntimeError, match="Error loading dummy_path"):
        extractor.parse_tups("dummy_path")
def test_extract_builds_documents_for_header_and_non_header(self, monkeypatch):
    # Header-less tuples become plain documents; headed ones are prefixed with
    # the header text.
    extractor = MarkdownExtractor(file_path="dummy_path")
    monkeypatch.setattr(extractor, "parse_tups", lambda _: [(None, "plain"), ("Header", "value")])
    docs = extractor.extract()
    assert [doc.page_content for doc in docs] == ["plain", "\n\nHeader\nvalue"]

View File

@ -1,93 +1,499 @@
from types import SimpleNamespace
from unittest import mock
import httpx
import pytest
from pytest_mock import MockerFixture
from core.rag.extractor import notion_extractor
# Shared fixture identifiers used by the payload builders below.
user_id = "user1"
database_id = "database1"
page_id = "page1"
# NOTE(review): this module-level extractor looks like residue of the old test
# module (the new tests construct their own instances); confirm it is still
# referenced before removing.
extractor = notion_extractor.NotionExtractor(
    notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
)
def _generate_page(page_title: str):
    """Build a minimal Notion page payload whose title property holds *page_title*."""
    title_fragment = {"type": "text", "text": {"content": page_title}, "plain_text": page_title}
    return {
        "object": "page",
        "id": page_id,
        "properties": {"Page": {"type": "title", "title": [title_fragment]}},
    }
def _generate_block(block_id: str, block_type: str, block_text: str):
    """Build a minimal Notion block payload of *block_type* containing *block_text*."""
    rich_text_entry = {"type": "text", "text": {"content": block_text}, "plain_text": block_text}
    return {
        "object": "block",
        "id": block_id,
        "parent": {"type": "page_id", "page_id": page_id},
        "type": block_type,
        "has_children": False,
        block_type: {"rich_text": [rich_text_entry]},
    }
def _mock_response(data):
def _mock_response(data, status_code: int = 200, text: str = ""):
response = mock.Mock()
response.status_code = 200
response.status_code = status_code
response.text = text
response.json.return_value = data
return response
def _remove_multiple_new_lines(text):
while "\n\n" in text:
text = text.replace("\n\n", "\n")
return text.strip()
class TestNotionExtractorInitAndPublicMethods:
    """Covers access-token resolution in __init__ and extract()/_load_data_as_documents routing."""

    def test_init_with_explicit_token(self):
        extractor = notion_extractor.NotionExtractor(
            notion_workspace_id="ws",
            notion_obj_id="obj",
            notion_page_type="page",
            tenant_id="tenant",
            notion_access_token="token",
        )
        assert extractor._notion_access_token == "token"

    def test_init_falls_back_to_env_token_when_credential_lookup_fails(self, monkeypatch):
        # When credential lookup raises, the configured integration token is used.
        monkeypatch.setattr(
            notion_extractor.NotionExtractor,
            "_get_access_token",
            classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))),
        )
        monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", "env-token", raising=False)
        extractor = notion_extractor.NotionExtractor(
            notion_workspace_id="ws",
            notion_obj_id="obj",
            notion_page_type="page",
            tenant_id="tenant",
            credential_id="cred",
        )
        assert extractor._notion_access_token == "env-token"

    def test_init_raises_if_no_credential_and_no_env_token(self, monkeypatch):
        # No credential and no integration token is a hard configuration error.
        monkeypatch.setattr(
            notion_extractor.NotionExtractor,
            "_get_access_token",
            classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))),
        )
        monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", None, raising=False)
        with pytest.raises(ValueError, match="Must specify `integration_token`"):
            notion_extractor.NotionExtractor(
                notion_workspace_id="ws",
                notion_obj_id="obj",
                notion_page_type="page",
                tenant_id="tenant",
                credential_id="cred",
            )

    def test_extract_updates_last_edited_and_loads_documents(self, monkeypatch):
        extractor = notion_extractor.NotionExtractor(
            notion_workspace_id="ws",
            notion_obj_id="obj",
            notion_page_type="page",
            tenant_id="tenant",
            notion_access_token="token",
        )
        update_mock = mock.Mock()
        load_mock = mock.Mock(return_value=[SimpleNamespace(page_content="doc")])
        monkeypatch.setattr(extractor, "update_last_edited_time", update_mock)
        monkeypatch.setattr(extractor, "_load_data_as_documents", load_mock)
        docs = extractor.extract()
        # extract() refreshes the last-edited timestamp, then loads the object.
        update_mock.assert_called_once_with(None)
        load_mock.assert_called_once_with("obj", "page")
        assert len(docs) == 1

    def test_load_data_as_documents_page_database_and_invalid(self, monkeypatch):
        extractor = notion_extractor.NotionExtractor(
            notion_workspace_id="ws",
            notion_obj_id="obj",
            notion_page_type="page",
            tenant_id="tenant",
            notion_access_token="token",
        )
        # Page type: block lines are joined with newlines into one document.
        monkeypatch.setattr(extractor, "_get_notion_block_data", lambda _: ["line1", "line2"])
        page_docs = extractor._load_data_as_documents("page-id", "page")
        assert page_docs[0].page_content == "line1\nline2"
        # Database type: delegates to _get_notion_database_data.
        monkeypatch.setattr(extractor, "_get_notion_database_data", lambda _: [SimpleNamespace(page_content="db")])
        db_docs = extractor._load_data_as_documents("db-id", "database")
        assert db_docs[0].page_content == "db"
        with pytest.raises(ValueError, match="notion page type not supported"):
            extractor._load_data_as_documents("obj", "unsupported")
def test_notion_page(mocker: MockerFixture):
    # NOTE(review): old-version module-level test left interleaved by the diff;
    # its act/assert lines appear inside TestNotionDatabase below. The visible
    # body only arranges mocked page blocks and patches httpx.request —
    # reconcile against the repository version.
    texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
    mocked_notion_page = {
        "object": "list",
        "results": [
            _generate_block("b1", "heading_1", texts[0]),
            _generate_block("b2", "heading_2", texts[1]),
            _generate_block("b3", "paragraph", texts[2]),
            _generate_block("b4", "heading_3", texts[3]),
        ],
        "next_cursor": None,
    }
    mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))
class TestNotionDatabase:
    # NOTE(review): this class continues beyond the visible chunk; only its
    # first test is documented here.

    def test_get_notion_database_data_parses_property_types_and_pagination(self, mocker: MockerFixture):
        extractor = notion_extractor.NotionExtractor(
            notion_workspace_id="ws",
            notion_obj_id="obj",
            notion_page_type="database",
            tenant_id="tenant",
            notion_access_token="token",
        )
        # NOTE(review): the next four lines reference page_id and page-type
        # loading and look like diff residue from the removed module-level
        # test_notion_page; confirm against the repository version.
        page_docs = extractor._load_data_as_documents(page_id, "page")
        assert len(page_docs) == 1
        content = _remove_multiple_new_lines(page_docs[0].page_content)
        assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
        # First page exercises every property type the parser handles; has_more
        # drives a second (empty) pagination request.
        first_page = {
            "results": [
                {
                    "properties": {
                        "tags": {
                            "type": "multi_select",
                            "multi_select": [{"name": "A"}, {"name": "B"}],
                        },
                        "title_prop": {"type": "title", "title": [{"plain_text": "Title"}]},
                        "empty_title": {"type": "title", "title": []},
                        "rich": {"type": "rich_text", "rich_text": [{"plain_text": "RichText"}]},
                        "empty_rich": {"type": "rich_text", "rich_text": []},
                        "select_prop": {"type": "select", "select": {"name": "Selected"}},
                        "empty_select": {"type": "select", "select": None},
                        "status_prop": {"type": "status", "status": {"name": "Open"}},
                        "empty_status": {"type": "status", "status": None},
                        "number_prop": {"type": "number", "number": 10},
                        "dict_prop": {"type": "date", "date": {"start": "2024-01-01", "end": None}},
                    },
                    "url": "https://notion.so/page-1",
                }
            ],
            "has_more": True,
            "next_cursor": "cursor-2",
        }
        second_page = {"results": [], "has_more": False, "next_cursor": None}
        mock_post = mocker.patch("httpx.post", side_effect=[_mock_response(first_page), _mock_response(second_page)])
        docs = extractor._get_notion_database_data("db-1", query_dict={"filter": {"x": 1}})
        assert len(docs) == 1
        content = docs[0].page_content
        assert "tags:['A', 'B']" in content
        assert "title_prop:Title" in content
        assert "rich:RichText" in content
        assert "number_prop:10" in content
        assert "dict_prop:start:2024-01-01" in content
        assert "Row Page URL:https://notion.so/page-1" in content
        # Pagination must issue exactly two POSTs (first page + cursor follow-up).
        assert mock_post.call_count == 2
def test_get_notion_database_data_handles_missing_results_and_empty_content(self, mocker: MockerFixture):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="database",
tenant_id="tenant",
notion_access_token="token",
)
mocker.patch("httpx.post", return_value=_mock_response({"results": None}))
assert extractor._get_notion_database_data("db-1") == []
def test_get_notion_database_data_requires_access_token(self):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="database",
tenant_id="tenant",
notion_access_token="token",
)
extractor._notion_access_token = None
with pytest.raises(AssertionError, match="Notion access token is required"):
extractor._get_notion_database_data("db-1")
def test_notion_database(mocker: MockerFixture):
page_title_list = ["page1", "page2", "page3"]
mocked_notion_database = {
"object": "list",
"results": [_generate_page(i) for i in page_title_list],
"next_cursor": None,
}
mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
database_docs = extractor._load_data_as_documents(database_id, "database")
assert len(database_docs) == 1
content = _remove_multiple_new_lines(database_docs[0].page_content)
assert content == "\n".join([f"Page:{i}" for i in page_title_list])
# Tests for block-level reading: _get_notion_block_data, _read_block, and
# _read_table_rows, covering pagination, recursion, tables and error paths.
class TestNotionBlocks:
def test_get_notion_block_data_success_with_table_headings_children_and_pagination(self, mocker: MockerFixture):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
# First page mixes a table, headings, a paragraph with children, and a child_page.
first_response = {
"results": [
{"type": "table", "id": "tbl-1", "has_children": False, "table": {}},
{
"type": "heading_1",
"id": "h1",
"has_children": False,
"heading_1": {"rich_text": [{"text": {"content": "Heading"}}]},
},
{
"type": "paragraph",
"id": "p1",
"has_children": True,
"paragraph": {"rich_text": [{"text": {"content": "Paragraph"}}]},
},
{
"type": "child_page",
"id": "cp1",
"has_children": True,
"child_page": {"rich_text": []},
},
],
"next_cursor": "cursor-2",
}
second_response = {
"results": [
{
"type": "heading_2",
"id": "h2",
"has_children": False,
"heading_2": {"rich_text": [{"text": {"content": "SubHeading"}}]},
}
],
"next_cursor": None,
}
# Two sequential HTTP responses drive the cursor-based pagination.
mocker.patch("httpx.request", side_effect=[_mock_response(first_response), _mock_response(second_response)])
mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE")
mocker.patch.object(extractor, "_read_block", return_value="CHILD")
lines = extractor._get_notion_block_data("page-1")
assert lines[0] == "TABLE\n\n"
assert "# Heading" in lines[1]
assert "Paragraph\nCHILD\n\n" in lines[2]
assert "## SubHeading" in lines[-1]
def test_get_notion_block_data_handles_http_error_and_invalid_payload(self, mocker: MockerFixture):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
# Network failure, malformed payload, and a 500 status must all surface as ValueError.
mocker.patch("httpx.request", side_effect=httpx.HTTPError("network"))
with pytest.raises(ValueError, match="Error fetching Notion block data"):
extractor._get_notion_block_data("page-1")
mocker.patch("httpx.request", return_value=_mock_response({"bad": "payload"}, status_code=200))
with pytest.raises(ValueError, match="Error fetching Notion block data"):
extractor._get_notion_block_data("page-1")
mocker.patch("httpx.request", return_value=_mock_response({"results": []}, status_code=500, text="boom"))
with pytest.raises(ValueError, match="Error fetching Notion block data: boom"):
extractor._get_notion_block_data("page-1")
def test_read_block_supports_heading_table_and_recursion(self, mocker: MockerFixture):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
root_payload = {
"results": [
{
"type": "heading_2",
"id": "h2",
"has_children": False,
"heading_2": {"rich_text": [{"text": {"content": "Root"}}]},
},
{
"type": "paragraph",
"id": "child-block",
"has_children": True,
"paragraph": {"rich_text": [{"text": {"content": "Parent"}}]},
},
{"type": "table", "id": "tbl-1", "has_children": False, "table": {}},
],
"next_cursor": None,
}
child_payload = {
"results": [
{
"type": "paragraph",
"id": "leaf",
"has_children": False,
"paragraph": {"rich_text": [{"text": {"content": "Child"}}]},
}
],
"next_cursor": None,
}
# Second response serves the recursive fetch for the child block.
mocker.patch("httpx.request", side_effect=[_mock_response(root_payload), _mock_response(child_payload)])
mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE-MD")
content = extractor._read_block("root")
assert "## Root" in content
assert "Parent" in content
assert "Child" in content
assert "TABLE-MD" in content
def test_read_block_breaks_on_missing_results(self, mocker: MockerFixture):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
# A null "results" payload ends reading immediately with empty content.
mocker.patch("httpx.request", return_value=_mock_response({"results": None, "next_cursor": None}))
assert extractor._read_block("root") == ""
def test_read_table_rows_formats_markdown_with_pagination(self, mocker: MockerFixture):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
page_one = {
"results": [
{
"table_row": {
"cells": [
[{"text": {"content": "H1"}}],
[{"text": {"content": "H2"}}],
]
}
},
{
"table_row": {
"cells": [
[{"text": {"content": "R1C1"}}],
[{"text": {"content": "R1C2"}}],
]
}
},
],
"next_cursor": "next",
}
page_two = {
"results": [
{
"table_row": {
"cells": [
[{"text": {"content": "H1"}}],
[],
]
}
},
{
"table_row": {
"cells": [
[{"text": {"content": "R2C1"}}],
[{"text": {"content": "R2C2"}}],
]
}
},
],
"next_cursor": None,
}
mocker.patch("httpx.request", side_effect=[_mock_response(page_one), _mock_response(page_two)])
markdown = extractor._read_table_rows("tbl-1")
assert "| H1 | H2 |" in markdown
assert "| R1C1 | R1C2 |" in markdown
# Empty cells render as an empty markdown column.
assert "| H1 | |" in markdown
assert "| R2C1 | R2C2 |" in markdown
# Tests for last-edited-time bookkeeping and credential resolution.
class TestNotionMetadataAndCredentialMethods:
def test_update_last_edited_time_no_document_model(self):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
# Passing None is a no-op (returns None rather than raising).
assert extractor.update_last_edited_time(None) is None
def test_update_last_edited_time_updates_document_and_commits(self, monkeypatch):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
# Fake ORM surface: records update() payloads and whether commit() ran.
class FakeDocumentModel:
data_source_info = "data_source_info"
update_calls = []
class FakeQuery:
def filter_by(self, **kwargs):
return self
def update(self, payload):
update_calls.append(payload)
class FakeSession:
committed = False
def query(self, model):
assert model is FakeDocumentModel
return FakeQuery()
def commit(self):
self.committed = True
fake_db = SimpleNamespace(session=FakeSession())
monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
monkeypatch.setattr(notion_extractor, "db", fake_db)
monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")
doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
extractor.update_last_edited_time(doc_model)
assert update_calls
assert fake_db.session.committed is True
def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):
extractor_page = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="page-id",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
request_mock = mocker.patch(
"httpx.request", return_value=_mock_response({"last_edited_time": "2025-05-01T00:00:00.000Z"})
)
assert extractor_page.get_notion_last_edited_time() == "2025-05-01T00:00:00.000Z"
# Page extractors hit the /pages/<id> endpoint...
assert "pages/page-id" in request_mock.call_args[0][1]
extractor_db = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="db-id",
notion_page_type="database",
tenant_id="tenant",
notion_access_token="token",
)
request_mock = mocker.patch(
"httpx.request", return_value=_mock_response({"last_edited_time": "2025-06-01T00:00:00.000Z"})
)
assert extractor_db.get_notion_last_edited_time() == "2025-06-01T00:00:00.000Z"
# ...while database extractors hit /databases/<id>.
assert "databases/db-id" in request_mock.call_args[0][1]
def test_get_notion_last_edited_time_requires_access_token(self):
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="ws",
notion_obj_id="obj",
notion_page_type="page",
tenant_id="tenant",
notion_access_token="token",
)
extractor._notion_access_token = None
with pytest.raises(AssertionError, match="Notion access token is required"):
extractor.get_notion_last_edited_time()
def test_get_access_token_success_and_errors(self, monkeypatch):
# Missing credential id, missing credential record, and the success path.
with pytest.raises(Exception, match="No credential id found"):
notion_extractor.NotionExtractor._get_access_token("tenant", None)
class FakeProviderServiceMissing:
def get_datasource_credentials(self, **kwargs):
return None
monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceMissing)
with pytest.raises(Exception, match="No notion credential found"):
notion_extractor.NotionExtractor._get_access_token("tenant", "cred")
class FakeProviderServiceFound:
def get_datasource_credentials(self, **kwargs):
return {"integration_secret": "token-from-credential"}
monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceFound)
assert notion_extractor.NotionExtractor._get_access_token("tenant", "cred") == "token-from-credential"

View File

@ -0,0 +1,79 @@
from pathlib import Path
from types import SimpleNamespace
import pytest
import core.rag.extractor.text_extractor as text_module
from core.rag.extractor.text_extractor import TextExtractor
# Tests for TextExtractor: plain read, encoding autodetection fallback, and
# the RuntimeError wrapping of both decode and non-decode failures.
class TestTextExtractor:
def test_extract_success(self, tmp_path):
file_path = tmp_path / "data.txt"
file_path.write_text("hello world", encoding="utf-8")
extractor = TextExtractor(str(file_path))
docs = extractor.extract()
assert len(docs) == 1
assert docs[0].page_content == "hello world"
assert docs[0].metadata == {"source": str(file_path)}
def test_extract_autodetect_success_after_decode_error(self, monkeypatch):
extractor = TextExtractor("dummy.txt", autodetect_encoding=True)
calls = []
# Fails for the default encoding and the first detected one ("bad"),
# succeeds on the second detected encoding.
def fake_read_text(self, encoding=None):
calls.append(encoding)
if encoding is None:
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
if encoding == "bad":
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
return "decoded text"
monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True)
monkeypatch.setattr(
text_module,
"detect_file_encodings",
lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")],
)
docs = extractor.extract()
assert docs[0].page_content == "decoded text"
# Verifies the exact retry order: default first, then detected encodings.
assert calls == [None, "bad", "utf-8"]
def test_extract_autodetect_all_fail_raises_runtime_error(self, monkeypatch):
extractor = TextExtractor("dummy.txt", autodetect_encoding=True)
def always_decode_error(self, encoding=None):
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True)
monkeypatch.setattr(text_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="latin-1")])
with pytest.raises(RuntimeError, match="all detected encodings failed"):
extractor.extract()
def test_extract_decode_error_without_autodetect_raises_runtime_error(self, monkeypatch):
extractor = TextExtractor("dummy.txt", autodetect_encoding=False)
def always_decode_error(self, encoding=None):
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True)
with pytest.raises(RuntimeError, match="specified encoding failed"):
extractor.extract()
def test_extract_wraps_non_decode_exceptions(self, monkeypatch):
extractor = TextExtractor("dummy.txt")
def raise_other(self, encoding=None):
raise OSError("io error")
monkeypatch.setattr(Path, "read_text", raise_other, raising=True)
with pytest.raises(RuntimeError, match="Error loading dummy.txt"):
extractor.extract()

View File

@ -3,9 +3,12 @@
import io
import os
import tempfile
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from docx import Document
from docx.oxml import OxmlElement
from docx.oxml.ns import qn
@ -136,7 +139,7 @@ def test_extract_images_from_docx(monkeypatch):
monkeypatch.setattr(we, "UploadFile", FakeUploadFile)
# Patch external image fetcher
def fake_get(url: str):
def fake_get(url: str, **kwargs):
assert url == "https://example.com/image.png"
return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes)
@ -203,10 +206,8 @@ def test_extract_images_from_docx_uses_internal_files_url():
finally:
# Restore original values
if original_files_url is not None:
dify_config.FILES_URL = original_files_url
if original_internal_files_url is not None:
dify_config.INTERNAL_FILES_URL = original_internal_files_url
dify_config.FILES_URL = original_files_url
dify_config.INTERNAL_FILES_URL = original_internal_files_url
def test_extract_hyperlinks(monkeypatch):
@ -314,3 +315,313 @@ def test_extract_legacy_hyperlinks(monkeypatch):
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
def test_init_rejects_invalid_url_status(monkeypatch):
    """Constructing from a URL must raise on a non-success status and close the response."""

    class StubResponse:
        status_code = 404
        content = b""
        closed = False

        def close(self):
            self.closed = True

    response = StubResponse()
    monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=lambda url, **kwargs: response))
    with pytest.raises(ValueError, match="returned status code 404"):
        WordExtractor("https://example.com/missing.docx", "tenant", "user")
    # The HTTP response must not leak even on the error path.
    assert response.closed is True
def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
    """'~' paths are expanded to a real file; unresolvable paths are rejected."""
    expanded = tmp_path / "expanded.docx"
    expanded.write_bytes(b"docx")
    monkeypatch.setattr(we.os.path, "expanduser", lambda p: str(expanded))
    monkeypatch.setattr(we.os.path, "isfile", lambda p: p == str(expanded))
    extractor = WordExtractor("~/expanded.docx", "tenant", "user")
    assert extractor.file_path == str(expanded)
    # When nothing resolves to a file (and it's not a URL), the constructor raises.
    monkeypatch.setattr(we.os.path, "isfile", lambda p: False)
    with pytest.raises(ValueError, match="is not a valid file or url"):
        WordExtractor("not-a-file", "tenant", "user")
def test_del_closes_temp_file():
    """__del__ must close the temp file handle when one was created."""
    # Bypass __init__ so no real file/URL handling happens.
    extractor = object.__new__(WordExtractor)
    temp_handle = MagicMock()
    extractor.temp_file = temp_handle
    WordExtractor.__del__(extractor)
    temp_handle.close.assert_called_once()
# Verifies _extract_images_from_docx tolerates every bad relationship shape:
# unparseable external URL, request failure, unknown MIME type, and an internal
# part whose target_ref yields no extension — all must be skipped, yielding {}.
def test_extract_images_handles_invalid_external_cases(monkeypatch):
# target_ref that claims to contain "image" but splits to no usable URL.
class FakeTargetRef:
def __contains__(self, item):
return item == "image"
def split(self, sep):
return [None]
rel_invalid_url = SimpleNamespace(is_external=True, target_ref="image-no-url")
rel_request_error = SimpleNamespace(is_external=True, target_ref="https://example.com/image-error")
rel_unknown_mime = SimpleNamespace(is_external=True, target_ref="https://example.com/image-unknown")
rel_internal_none_ext = SimpleNamespace(is_external=False, target_ref=FakeTargetRef(), target_part=object())
doc = SimpleNamespace(
part=SimpleNamespace(
rels={
"r1": rel_invalid_url,
"r2": rel_request_error,
"r3": rel_unknown_mime,
"r4": rel_internal_none_ext,
}
)
)
def fake_get(url, **kwargs):
if "image-error" in url:
raise RuntimeError("network")
return SimpleNamespace(status_code=200, headers={"Content-Type": "application/unknown"}, content=b"x")
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=MagicMock()))
monkeypatch.setattr(we, "db", db_stub)
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None))
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
extractor = object.__new__(WordExtractor)
extractor.tenant_id = "tenant"
extractor.user_id = "user"
result = extractor._extract_images_from_docx(doc)
# No image survives the bad inputs, but the session is still committed once.
assert result == {}
db_stub.session.commit.assert_called_once()
# Covers _table_to_markdown header/separator/row layout and the cell/paragraph
# parsers' handling of external vs internal image blips and plain-text runs.
def test_table_to_markdown_and_parse_helpers(monkeypatch):
extractor = object.__new__(WordExtractor)
table = SimpleNamespace(
rows=[
SimpleNamespace(cells=[1, 2]),
SimpleNamespace(cells=[3, 4]),
]
)
# First _parse_row call supplies the header, second the data row.
parse_row_mock = MagicMock(side_effect=[["H1", "H2"], ["A", "B"]])
monkeypatch.setattr(extractor, "_parse_row", parse_row_mock)
markdown = extractor._table_to_markdown(table, {})
assert markdown == "| H1 | H2 |\n| --- | --- |\n| A | B |"
class FakeRunElement:
def __init__(self, blips):
self._blips = blips
def xpath(self, pattern):
if pattern == ".//a:blip":
return self._blips
return []
class FakeBlip:
def __init__(self, image_id):
self.image_id = image_id
def get(self, key):
return self.image_id
image_part = object()
# One run with a None blip (ignored), an external and an internal image;
# a second run with no blips but plain text.
paragraph = SimpleNamespace(
runs=[
SimpleNamespace(element=FakeRunElement([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")]), text=""),
SimpleNamespace(element=FakeRunElement([]), text="plain"),
],
part=SimpleNamespace(
rels={
"ext": SimpleNamespace(is_external=True),
"int": SimpleNamespace(is_external=False, target_part=image_part),
}
),
)
# External images are keyed by rel id, internal ones by target part.
image_map = {"ext": "EXT-IMG", image_part: "INT-IMG"}
assert extractor._parse_cell_paragraph(paragraph, image_map) == "EXT-IMGINT-IMGplain"
cell = SimpleNamespace(paragraphs=[paragraph, paragraph])
assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain"
# End-to-end parse_docx test over a fully fake python-docx DOM: drawings with
# external/internal blips, VML shapes (binData/imagedata), hyperlinks including
# one whose relationship lookup raises, and the table branch of the body walk.
def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch):
extractor = object.__new__(WordExtractor)
ext_image_id = "ext-image"
int_embed_id = "int-embed"
shape_ext_id = "shape-ext"
shape_int_id = "shape-int"
internal_part = object()
shape_internal_part = object()
# rels.get raises for "link-bad" to exercise the hyperlink error path.
class Rels(UserDict):
def get(self, key, default=None):
if key == "link-bad":
raise RuntimeError("cannot resolve relation")
return super().get(key, default)
rels = Rels(
{
ext_image_id: SimpleNamespace(is_external=True, target_ref="https://img/ext.png"),
int_embed_id: SimpleNamespace(is_external=False, target_part=internal_part),
shape_ext_id: SimpleNamespace(is_external=True, target_ref="https://img/shape.png"),
shape_int_id: SimpleNamespace(is_external=False, target_part=shape_internal_part),
"link-ok": SimpleNamespace(is_external=True, target_ref="https://example.com"),
}
)
image_map = {
ext_image_id: "[EXT]",
internal_part: "[INT]",
shape_ext_id: "[SHAPE_EXT]",
shape_internal_part: "[SHAPE_INT]",
}
class FakeBlip:
def __init__(self, embed_id):
self.embed_id = embed_id
def get(self, key):
return self.embed_id
class FakeDrawing:
def __init__(self, embed_ids):
self.embed_ids = embed_ids
def findall(self, pattern):
return [FakeBlip(embed_id) for embed_id in self.embed_ids]
class FakeNode:
def __init__(self, text=None, attrs=None):
self.text = text
self._attrs = attrs or {}
def get(self, key):
return self._attrs.get(key)
# VML shape: binData carries an r:id, imagedata carries a plain id.
class FakeShape:
def __init__(self, bin_id=None, img_id=None):
self.bin_id = bin_id
self.img_id = img_id
def find(self, pattern):
if "binData" in pattern and self.bin_id:
return FakeNode(
text="shape",
attrs={"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id": self.bin_id},
)
if "imagedata" in pattern and self.img_id:
return FakeNode(attrs={"id": self.img_id})
return None
# Minimal stand-in for a w:r / w:hyperlink XML element.
class FakeChild:
def __init__(
self,
tag,
text="",
fld_chars=None,
instr_texts=None,
drawings=None,
shapes=None,
attrs=None,
hyperlink_runs=None,
):
self.tag = tag
self.text = text
self._fld_chars = fld_chars or []
self._instr_texts = instr_texts or []
self._drawings = drawings or []
self._shapes = shapes or []
self._attrs = attrs or {}
self._hyperlink_runs = hyperlink_runs or []
def findall(self, pattern):
if pattern == qn("w:fldChar"):
return self._fld_chars
if pattern == qn("w:instrText"):
return self._instr_texts
if pattern == qn("w:r"):
return self._hyperlink_runs
if pattern.endswith("}drawing"):
return self._drawings
if pattern.endswith("}pict"):
return self._shapes
return []
def get(self, key):
return self._attrs.get(key)
class FakeRun:
def __init__(self, element, paragraph):
self.element = element
self.text = getattr(element, "text", "")
paragraph_main = SimpleNamespace(
_element=[
FakeChild(
qn("w:r"),
text="run-text",
drawings=[FakeDrawing([ext_image_id, int_embed_id])],
shapes=[FakeShape(bin_id=shape_ext_id, img_id=shape_int_id)],
),
FakeChild(
qn("w:r"),
text="",
drawings=[],
shapes=[FakeShape(bin_id=shape_ext_id)],
),
FakeChild(
qn("w:hyperlink"),
attrs={qn("r:id"): "link-ok"},
hyperlink_runs=[FakeChild(qn("w:r"), text="LinkText")],
),
FakeChild(
qn("w:hyperlink"),
attrs={qn("r:id"): "link-bad"},
hyperlink_runs=[FakeChild(qn("w:r"), text="BrokenLink")],
),
]
)
paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")])
# Body declares two paragraphs and one table to drive both walk branches.
fake_doc = SimpleNamespace(
part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}),
paragraphs=[paragraph_main, paragraph_empty],
tables=[SimpleNamespace(rows=[])],
element=SimpleNamespace(
body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")]
),
)
monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc)
monkeypatch.setattr(we, "Run", FakeRun)
monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map)
monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN")
logger_exception = MagicMock()
monkeypatch.setattr(we.logger, "exception", logger_exception)
content = extractor.parse_docx("dummy.docx")
assert "[EXT]" in content
assert "[INT]" in content
assert "[SHAPE_EXT]" in content
assert "[LinkText](https://example.com)" in content
# The bad hyperlink degrades to plain text and logs exactly one exception.
assert "BrokenLink" in content
assert "TABLE-MARKDOWN" in content
logger_exception.assert_called_once()

View File

@ -0,0 +1,300 @@
"""Unit tests for unstructured extractors and their local/API partitioning paths."""
import base64
import sys
import types
from types import SimpleNamespace
import pytest
import core.rag.extractor.unstructured.unstructured_epub_extractor as epub_module
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
def _register_module(monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object) -> types.ModuleType:
    """Build a throwaway module carrying *attrs* and install it under *name* in sys.modules."""
    fake = types.ModuleType(name)
    for attr_name, attr_value in attrs.items():
        setattr(fake, attr_name, attr_value)
    # monkeypatch.setitem undoes the sys.modules entry after the test.
    monkeypatch.setitem(sys.modules, name, fake)
    return fake
def _register_unstructured_packages(monkeypatch: pytest.MonkeyPatch) -> None:
    """Install empty package stand-ins so 'unstructured' submodule imports resolve."""
    for package_name in (
        "unstructured",
        "unstructured.partition",
        "unstructured.chunking",
        "unstructured.file_utils",
    ):
        # __path__=[] marks each fake module as a package.
        _register_module(monkeypatch, package_name, __path__=[])
def _install_chunk_by_title(monkeypatch: pytest.MonkeyPatch, chunks: list[SimpleNamespace]) -> None:
    """Register a chunk_by_title stub that ignores its inputs and returns *chunks*."""
    _register_unstructured_packages(monkeypatch)

    def fake_chunk_by_title(
        elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int
    ) -> list[SimpleNamespace]:
        # Signature mirrors the real API; the canned chunks are returned unchanged.
        return chunks

    _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=fake_chunk_by_title)
# Local-partition vs API-partition paths for the markdown, msg and xml extractors.
class TestUnstructuredMarkdownMsgXml:
def test_markdown_extractor_without_api(self, monkeypatch):
# Chunk texts are stripped of surrounding whitespace by the extractor.
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" chunk-1 "), SimpleNamespace(text=" chunk-2 ")])
_register_module(
monkeypatch, "unstructured.partition.md", partition_md=lambda filename: [SimpleNamespace(text="x")]
)
docs = UnstructuredMarkdownExtractor("/tmp/file.md").extract()
assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"]
def test_markdown_extractor_with_api(self, monkeypatch):
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" via-api ")])
calls = {}
def partition_via_api(filename, api_url, api_key):
calls.update({"filename": filename, "api_url": api_url, "api_key": api_key})
return [SimpleNamespace(text="ignored")]
_register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api)
docs = UnstructuredMarkdownExtractor("/tmp/file.md", api_url="https://u", api_key="k").extract()
assert docs[0].page_content == "via-api"
# The API partitioner must receive exactly the configured endpoint/credentials.
assert calls == {"filename": "/tmp/file.md", "api_url": "https://u", "api_key": "k"}
def test_msg_extractor_local(self, monkeypatch):
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")])
_register_module(
monkeypatch, "unstructured.partition.msg", partition_msg=lambda filename: [SimpleNamespace(text="x")]
)
assert UnstructuredMsgExtractor("/tmp/file.msg").extract()[0].page_content == "msg-doc"
def test_msg_extractor_with_api(self, monkeypatch):
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")])
calls = {}
def partition_via_api(filename, api_url, api_key):
calls.update({"filename": filename, "api_url": api_url, "api_key": api_key})
return [SimpleNamespace(text="x")]
_register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api)
assert (
UnstructuredMsgExtractor("/tmp/file.msg", api_url="https://u", api_key="k").extract()[0].page_content
== "msg-doc"
)
assert calls["filename"] == "/tmp/file.msg"
def test_xml_extractor_local_and_api(self, monkeypatch):
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="xml-doc")])
xml_calls = {}
def partition_xml(filename, xml_keep_tags):
xml_calls.update({"filename": filename, "xml_keep_tags": xml_keep_tags})
return [SimpleNamespace(text="x")]
_register_module(monkeypatch, "unstructured.partition.xml", partition_xml=partition_xml)
assert UnstructuredXmlExtractor("/tmp/file.xml").extract()[0].page_content == "xml-doc"
# Local partitioning is expected to keep XML tags.
assert xml_calls == {"filename": "/tmp/file.xml", "xml_keep_tags": True}
api_calls = {}
def partition_via_api(filename, api_url, api_key):
api_calls.update({"filename": filename, "api_url": api_url, "api_key": api_key})
return [SimpleNamespace(text="x")]
_register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api)
assert (
UnstructuredXmlExtractor("/tmp/file.xml", api_url="https://u", api_key="k").extract()[0].page_content
== "xml-doc"
)
assert api_calls["filename"] == "/tmp/file.xml"
# Email extractor base64/HTML handling and the epub extractor's pandoc + API paths.
class TestUnstructuredEmailAndEpub:
def test_email_extractor_local_decodes_html_and_suppresses_decode_errors(self, monkeypatch):
_register_unstructured_packages(monkeypatch)
captured = {}
# Capture the elements handed to chunk_by_title so we can inspect decoding.
def chunk_by_title(
elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int
) -> list[SimpleNamespace]:
captured["elements"] = list(elements)
return [SimpleNamespace(text=" chunked-email ")]
_register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title)
html = "<p>Hello Email</p>"
encoded_html = base64.b64encode(html.encode("utf-8")).decode("utf-8")
bad_base64 = "not-base64"
elements = [SimpleNamespace(text=encoded_html), SimpleNamespace(text=bad_base64)]
_register_module(monkeypatch, "unstructured.partition.email", partition_email=lambda filename: elements)
docs = UnstructuredEmailExtractor("/tmp/file.eml").extract()
assert docs[0].page_content == "chunked-email"
chunk_elements = captured["elements"]
# Valid base64 HTML is decoded; undecodable text is passed through unchanged.
assert "Hello Email" in chunk_elements[0].text
assert chunk_elements[1].text == bad_base64
def test_email_extractor_with_api(self, monkeypatch):
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="api-email")])
_register_module(
monkeypatch,
"unstructured.partition.api",
partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="abc")],
)
docs = UnstructuredEmailExtractor("/tmp/file.eml", api_url="https://u", api_key="k").extract()
assert docs[0].page_content == "api-email"
def test_epub_extractor_local_and_api(self, monkeypatch):
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="epub-doc")])
calls = {"download": 0, "partition": 0}
def fake_download_pandoc():
calls["download"] += 1
def partition_epub(filename, xml_keep_tags):
calls["partition"] += 1
assert xml_keep_tags is True
return [SimpleNamespace(text="x")]
monkeypatch.setattr(epub_module.pypandoc, "download_pandoc", fake_download_pandoc)
_register_module(monkeypatch, "unstructured.partition.epub", partition_epub=partition_epub)
docs = UnstructuredEpubExtractor("/tmp/file.epub").extract()
assert docs[0].page_content == "epub-doc"
# Local path must download pandoc exactly once and partition exactly once.
assert calls == {"download": 1, "partition": 1}
_register_module(
monkeypatch,
"unstructured.partition.api",
partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="x")],
)
docs = UnstructuredEpubExtractor("/tmp/file.epub", api_url="https://u", api_key="k").extract()
assert docs[0].page_content == "epub-doc"
# PPT requires the hosted API; PPTX supports both local and API partitioning.
# Both group element text by page_number and drop elements without a page.
class TestUnstructuredPPTAndPPTX:
def test_ppt_extractor_requires_api_url(self):
with pytest.raises(NotImplementedError, match="Unstructured API Url is not configured"):
UnstructuredPPTExtractor("/tmp/file.ppt").extract()
def test_ppt_extractor_groups_text_by_page(self, monkeypatch):
_register_unstructured_packages(monkeypatch)
_register_module(
monkeypatch,
"unstructured.partition.api",
partition_via_api=lambda filename, api_url, api_key: [
SimpleNamespace(text="A", metadata=SimpleNamespace(page_number=1)),
SimpleNamespace(text="B", metadata=SimpleNamespace(page_number=1)),
SimpleNamespace(text="skip", metadata=SimpleNamespace(page_number=None)),
SimpleNamespace(text="C", metadata=SimpleNamespace(page_number=2)),
],
)
docs = UnstructuredPPTExtractor("/tmp/file.ppt", api_url="https://u", api_key="k").extract()
# One document per page; elements with no page number are discarded.
assert [doc.page_content for doc in docs] == ["A\nB", "C"]
def test_pptx_extractor_local_and_api(self, monkeypatch):
_register_unstructured_packages(monkeypatch)
_register_module(
monkeypatch,
"unstructured.partition.pptx",
partition_pptx=lambda filename: [
SimpleNamespace(text="P1", metadata=SimpleNamespace(page_number=1)),
SimpleNamespace(text="P2", metadata=SimpleNamespace(page_number=2)),
SimpleNamespace(text="Skip", metadata=SimpleNamespace(page_number=None)),
],
)
docs = UnstructuredPPTXExtractor("/tmp/file.pptx").extract()
assert [doc.page_content for doc in docs] == ["P1", "P2"]
_register_module(
monkeypatch,
"unstructured.partition.api",
partition_via_api=lambda filename, api_url, api_key: [
SimpleNamespace(text="X", metadata=SimpleNamespace(page_number=1)),
SimpleNamespace(text="Y", metadata=SimpleNamespace(page_number=1)),
],
)
docs = UnstructuredPPTXExtractor("/tmp/file.pptx", api_url="https://u", api_key="k").extract()
assert [doc.page_content for doc in docs] == ["X\nY"]
class TestUnstructuredWord:
    """Tests for UnstructuredWordExtractor's .doc version gating and .doc/.docx partitioning paths."""
    def _install_doc_modules(self, monkeypatch, version: str, filetype_value):
        """Register fake unstructured modules reporting the given library version and detected filetype."""
        _register_unstructured_packages(monkeypatch)
        class FileType:
            DOC = "doc"
        _register_module(monkeypatch, "unstructured.__version__", __version__=version)
        _register_module(
            monkeypatch,
            "unstructured.file_utils.filetype",
            FileType=FileType,
            detect_filetype=lambda filename: filetype_value,
        )
        _register_module(
            monkeypatch,
            "unstructured.partition.api",
            partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="api-doc")],
        )
        _register_module(
            monkeypatch,
            "unstructured.partition.docx",
            partition_docx=lambda filename: [SimpleNamespace(text="docx-doc")],
        )
        _register_module(
            monkeypatch,
            "unstructured.chunking.title",
            chunk_by_title=lambda elements, max_characters, combine_text_under_n_chars: [
                SimpleNamespace(text="chunk-1"),
                SimpleNamespace(text="chunk-2"),
            ],
        )
    def test_word_extractor_rejects_doc_on_old_unstructured_version(self, monkeypatch):
        """.doc partitioning requires a minimum unstructured version; older versions raise ValueError."""
        self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="doc")
        with pytest.raises(ValueError, match="Partitioning .doc files is only supported"):
            UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract()
    def test_word_extractor_doc_and_docx_paths(self, monkeypatch):
        """Both the .doc branch (new enough version) and the .docx branch produce title-chunked docs."""
        self._install_doc_modules(monkeypatch, version="0.4.11", filetype_value="doc")
        docs = UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract()
        assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"]
        self._install_doc_modules(monkeypatch, version="0.5.0", filetype_value="not-doc")
        docs = UnstructuredWordExtractor("/tmp/file.docx", "https://u", "k").extract()
        assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"]
    def test_word_extractor_magic_import_error_fallback_to_extension(self, monkeypatch):
        """When the magic module is unavailable, filetype falls back to the file extension."""
        self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="not-used")
        # Setting the module to None makes `import magic` raise, forcing the fallback.
        monkeypatch.setitem(sys.modules, "magic", None)
        with pytest.raises(ValueError, match="Partitioning .doc files is only supported"):
            UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract()

View File

@ -0,0 +1,434 @@
"""Unit tests for WaterCrawl client, provider, and extractor behavior."""
import json
from typing import Any
from unittest.mock import MagicMock
import pytest
import core.rag.extractor.watercrawl.client as client_module
from core.rag.extractor.watercrawl.client import BaseAPIClient, WaterCrawlAPIClient
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
WaterCrawlBadRequestError,
WaterCrawlPermissionError,
)
from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
def _response(
status_code: int,
json_data: dict[str, Any] | None = None,
content_type: str = "application/json",
content: bytes = b"",
text: str = "",
) -> MagicMock:
response = MagicMock()
response.status_code = status_code
response.headers = {"Content-Type": content_type}
response.content = content
response.text = text
response.json.return_value = json_data if json_data is not None else {}
response.raise_for_status.return_value = None
response.close.return_value = None
return response
class TestWaterCrawlExceptions:
    """Tests for the WaterCrawl exception hierarchy's parsed fields and string forms."""
    def test_bad_request_error_properties_and_string(self):
        """Status code, message, and flattened errors are extracted from the response payload."""
        response = _response(400, {"message": "bad request", "errors": {"url": ["invalid"]}})
        err = WaterCrawlBadRequestError(response)
        # flat_errors is a JSON string; round-trip it to inspect the structure.
        parsed_errors = json.loads(err.flat_errors)
        assert err.status_code == 400
        assert err.message == "bad request"
        assert "url" in parsed_errors
        assert any("invalid" in str(item) for item in parsed_errors["url"])
        assert "WaterCrawlBadRequestError" in str(err)
    def test_permission_and_authentication_error_strings(self):
        """Permission and authentication errors render their dedicated explanatory messages."""
        response = _response(403, {"message": "quota exceeded", "errors": {}})
        permission = WaterCrawlPermissionError(response)
        authentication = WaterCrawlAuthenticationError(response)
        assert "exceeding your WaterCrawl API limits" in str(permission)
        assert "API key is invalid or expired" in str(authentication)
class TestBaseAPIClient:
    """Tests for BaseAPIClient session construction and request dispatch."""
    def test_init_session_builds_expected_headers(self, monkeypatch):
        """The httpx.Client is created with the API-key and plugin User-Agent headers."""
        captured = {}
        def fake_client(**kwargs):
            captured.update(kwargs)
            return "session"
        monkeypatch.setattr(client_module.httpx, "Client", fake_client)
        client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev")
        assert client.session == "session"
        assert captured["headers"]["X-API-Key"] == "k"
        assert captured["headers"]["User-Agent"] == "WaterCrawl-Plugin"
    def test_request_stream_and_non_stream_paths(self, monkeypatch):
        """Non-stream requests go through session.request; stream=True uses build_request + send."""
        class FakeSession:
            # Records every call so assertions can verify how the client used the session.
            def __init__(self):
                self.request_calls = []
                self.build_calls = []
                self.send_calls = []
            def request(self, method, url, params=None, json=None, **kwargs):
                self.request_calls.append((method, url, params, json, kwargs))
                return "non-stream-response"
            def build_request(self, method, url, params=None, json=None):
                req = (method, url, params, json)
                self.build_calls.append(req)
                return req
            def send(self, request, stream=False, **kwargs):
                self.send_calls.append((request, stream, kwargs))
                return "stream-response"
        fake_session = FakeSession()
        monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: fake_session)
        client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev")
        assert client._request("GET", "/v1/items", query_params={"a": 1}) == "non-stream-response"
        # The endpoint is joined onto the base URL.
        assert fake_session.request_calls[0][1] == "https://watercrawl.dev/v1/items"
        assert client._request("GET", "/v1/items", stream=True) == "stream-response"
        assert fake_session.build_calls
        assert fake_session.send_calls[0][1] is True
    def test_http_method_helpers_delegate_to_request(self, monkeypatch):
        """_get/_post/_put/_delete/_patch all forward to _request with the matching HTTP verb."""
        monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: MagicMock())
        client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev")
        calls = []
        def fake_request(method, endpoint, query_params=None, data=None, **kwargs):
            calls.append((method, endpoint, query_params, data))
            return "ok"
        monkeypatch.setattr(client, "_request", fake_request)
        assert client._get("/a") == "ok"
        assert client._post("/b", data={"x": 1}) == "ok"
        assert client._put("/c", data={"x": 2}) == "ok"
        assert client._delete("/d") == "ok"
        assert client._patch("/e", data={"x": 3}) == "ok"
        assert [c[0] for c in calls] == ["GET", "POST", "PUT", "DELETE", "PATCH"]
class TestWaterCrawlAPIClient:
    """Tests for WaterCrawlAPIClient response processing, event streaming, and endpoint wrappers."""
    def test_process_eventstream_and_download(self, monkeypatch):
        """SSE lines are parsed; result events are downloaded when requested and the response is closed."""
        client = WaterCrawlAPIClient(api_key="k")
        response = MagicMock()
        response.iter_lines.return_value = [
            b"event: keep-alive",
            b'data: {"type":"result","data":{"result":"http://x"}}',
            b'data: {"type":"log","data":{"msg":"ok"}}',
        ]
        monkeypatch.setattr(client, "download_result", lambda data: {"result": {"markdown": "body"}, "url": "u"})
        events = list(client.process_eventstream(response, download=True))
        assert events[0]["data"]["result"]["markdown"] == "body"
        assert events[1]["type"] == "log"
        response.close.assert_called_once()
    @pytest.mark.parametrize(
        ("status", "expected_exception"),
        [
            (401, WaterCrawlAuthenticationError),
            (403, WaterCrawlPermissionError),
            (422, WaterCrawlBadRequestError),
        ],
    )
    def test_process_response_error_statuses(self, status: int, expected_exception: type[Exception]):
        """Each error status code maps to its dedicated WaterCrawl exception type."""
        client = WaterCrawlAPIClient(api_key="k")
        with pytest.raises(expected_exception):
            client.process_response(_response(status, {"message": "bad", "errors": {"url": ["x"]}}))
    def test_process_response_204_returns_none(self):
        """A 204 No Content response yields None."""
        client = WaterCrawlAPIClient(api_key="k")
        assert client.process_response(_response(204, None)) is None
    def test_process_response_json_payloads(self):
        """JSON responses return the parsed payload; an empty body becomes an empty dict."""
        client = WaterCrawlAPIClient(api_key="k")
        assert client.process_response(_response(200, {"ok": True})) == {"ok": True}
        assert client.process_response(_response(200, None)) == {}
    def test_process_response_octet_stream_returns_bytes(self):
        """Binary responses are returned as raw bytes."""
        client = WaterCrawlAPIClient(api_key="k")
        assert (
            client.process_response(_response(200, content_type="application/octet-stream", content=b"bin")) == b"bin"
        )
    def test_process_response_event_stream_returns_generator(self, monkeypatch):
        """text/event-stream responses are handed off to process_eventstream unchanged."""
        client = WaterCrawlAPIClient(api_key="k")
        generator = (item for item in [{"type": "result", "data": {}}])
        monkeypatch.setattr(client, "process_eventstream", lambda response, download=False: generator)
        assert client.process_response(_response(200, content_type="text/event-stream")) is generator
    def test_process_response_unknown_content_type_raises(self):
        """An unrecognized content type raises an 'Unknown response type' error."""
        client = WaterCrawlAPIClient(api_key="k")
        with pytest.raises(Exception, match="Unknown response type"):
            client.process_response(_response(200, content_type="text/plain", text="x"))
    def test_process_response_uses_raise_for_status(self):
        """Unmapped HTTP errors fall through to the response's own raise_for_status."""
        client = WaterCrawlAPIClient(api_key="k")
        response = _response(500, {"message": "server"})
        response.raise_for_status.side_effect = RuntimeError("http error")
        with pytest.raises(RuntimeError, match="http error"):
            client.process_response(response)
    def test_endpoint_wrappers(self, monkeypatch):
        """Every endpoint wrapper routes its HTTP result through process_response."""
        client = WaterCrawlAPIClient(api_key="k")
        monkeypatch.setattr(client, "process_response", lambda resp: "processed")
        monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "get-resp")
        monkeypatch.setattr(client, "_post", lambda *args, **kwargs: "post-resp")
        monkeypatch.setattr(client, "_delete", lambda *args, **kwargs: "delete-resp")
        assert client.get_crawl_requests_list() == "processed"
        assert client.get_crawl_request("id") == "processed"
        assert client.create_crawl_request(url="https://x") == "processed"
        assert client.stop_crawl_request("id") == "processed"
        assert client.download_crawl_request("id") == "processed"
        assert client.get_crawl_request_results("id") == "processed"
    def test_monitor_crawl_request_generator_and_validation(self, monkeypatch):
        """monitor_crawl_request yields events from a generator and rejects non-generator results."""
        client = WaterCrawlAPIClient(api_key="k")
        monkeypatch.setattr(client, "process_response", lambda _: (x for x in [{"type": "result", "data": 1}]))
        monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "stream-resp")
        events = list(client.monitor_crawl_request("job-1", prefetched=True))
        assert events == [{"type": "result", "data": 1}]
        # A plain list (not a generator) must be rejected.
        monkeypatch.setattr(client, "process_response", lambda _: [{"type": "result"}])
        with pytest.raises(ValueError, match="Generator expected"):
            list(client.monitor_crawl_request("job-1"))
    def test_scrape_url_sync_and_async(self, monkeypatch):
        """sync=False returns the created request; sync=True waits for the monitored result event."""
        client = WaterCrawlAPIClient(api_key="k")
        monkeypatch.setattr(client, "create_crawl_request", lambda **kwargs: {"uuid": "job-1"})
        async_result = client.scrape_url("https://example.com", sync=False)
        assert async_result == {"uuid": "job-1"}
        monkeypatch.setattr(
            client,
            "monitor_crawl_request",
            lambda item_id, prefetched: iter(
                [{"type": "log", "data": {}}, {"type": "result", "data": {"url": "https://example.com"}}]
            ),
        )
        sync_result = client.scrape_url("https://example.com", sync=True)
        assert sync_result == {"url": "https://example.com"}
    def test_download_result_fetches_json_and_closes(self, monkeypatch):
        """download_result fetches the result URL, replaces it with parsed JSON, and closes the response."""
        client = WaterCrawlAPIClient(api_key="k")
        response = _response(200, {"markdown": "body"})
        monkeypatch.setattr(client_module.httpx, "get", lambda *args, **kwargs: response)
        result = client.download_result({"result": "https://example.com/result.json"})
        assert result["result"] == {"markdown": "body"}
        response.close.assert_called_once()
class TestWaterCrawlProvider:
    """Tests for WaterCrawlProvider option building, status mapping, and result pagination."""
    def test_crawl_url_builds_options_and_min_wait_time(self, monkeypatch):
        """Raw crawl options are translated into spider/page options; wait_time is raised to 1000ms."""
        provider = WaterCrawlProvider(api_key="k")
        captured_kwargs = {}
        def create_crawl_request_spy(**kwargs):
            captured_kwargs.update(kwargs)
            return {"uuid": "job-1"}
        monkeypatch.setattr(provider.client, "create_crawl_request", create_crawl_request_spy)
        result = provider.crawl_url(
            "https://example.com",
            {
                "crawl_sub_pages": True,
                "limit": 5,
                "max_depth": 2,
                "includes": "a,b",
                "excludes": "x,y",
                "exclude_tags": "nav,footer",
                "include_tags": "main",
                "wait_time": 100,
                "only_main_content": False,
            },
        )
        assert result == {"status": "active", "job_id": "job-1"}
        assert captured_kwargs["url"] == "https://example.com"
        # Comma-separated strings become lists in the spider options.
        assert captured_kwargs["spider_options"] == {
            "max_depth": 2,
            "page_limit": 5,
            "allowed_domains": [],
            "exclude_paths": ["x", "y"],
            "include_paths": ["a", "b"],
        }
        assert captured_kwargs["page_options"]["exclude_tags"] == ["nav", "footer"]
        assert captured_kwargs["page_options"]["include_tags"] == ["main"]
        assert captured_kwargs["page_options"]["only_main_content"] is False
        # wait_time below the minimum (100) is clamped up to 1000.
        assert captured_kwargs["page_options"]["wait_time"] == 1000
    def test_get_crawl_status_active_and_completed(self, monkeypatch):
        """Running jobs map to 'active' with no data; completed jobs include fetched results."""
        provider = WaterCrawlProvider(api_key="k")
        monkeypatch.setattr(
            provider.client,
            "get_crawl_request",
            lambda job_id: {
                "status": "running",
                "uuid": job_id,
                "options": {"spider_options": {"page_limit": 3}},
                "number_of_documents": 1,
                "duration": "00:00:01.500000",
            },
        )
        active = provider.get_crawl_status("job-1")
        assert active["status"] == "active"
        assert active["data"] == []
        # Duration string "HH:MM:SS.ffffff" is converted to seconds.
        assert active["time_consuming"] == pytest.approx(1.5)
        monkeypatch.setattr(
            provider.client,
            "get_crawl_request",
            lambda job_id: {
                "status": "completed",
                "uuid": job_id,
                "options": {"spider_options": {"page_limit": 2}},
                "number_of_documents": 2,
                "duration": "00:00:02.000000",
            },
        )
        monkeypatch.setattr(provider, "_get_results", lambda crawl_request_id, query_params=None: iter([{"url": "u"}]))
        completed = provider.get_crawl_status("job-2")
        assert completed["status"] == "completed"
        assert completed["data"] == [{"url": "u"}]
    def test_get_crawl_url_data_and_scrape(self, monkeypatch):
        """An empty job id falls back to scraping; otherwise results are matched by source_url."""
        provider = WaterCrawlProvider(api_key="k")
        monkeypatch.setattr(provider, "scrape_url", lambda url: {"source_url": url})
        assert provider.get_crawl_url_data("", "https://example.com") == {"source_url": "https://example.com"}
        monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([{"source_url": "u1"}]))
        assert provider.get_crawl_url_data("job", "u1") == {"source_url": "u1"}
        # No matching result yields None.
        monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([]))
        assert provider.get_crawl_url_data("job", "u1") is None
    def test_structure_data_validation_and_get_results_pagination(self, monkeypatch):
        """_structure_data rejects non-dict results; _get_results follows 'next' pages until exhausted."""
        provider = WaterCrawlProvider(api_key="k")
        with pytest.raises(ValueError, match="Invalid result object"):
            provider._structure_data({"result": "not-a-dict"})
        structured = provider._structure_data(
            {
                "url": "https://example.com",
                "result": {
                    "metadata": {"title": "Title", "description": "Desc"},
                    "markdown": "Body",
                },
            }
        )
        assert structured["title"] == "Title"
        assert structured["markdown"] == "Body"
        # Two pages: the first with one result and a next link, the second empty and final.
        responses = [
            {
                "results": [
                    {
                        "url": "https://a",
                        "result": {"metadata": {"title": "A", "description": "DA"}, "markdown": "MA"},
                    }
                ],
                "next": "next-page",
            },
            {"results": [], "next": None},
        ]
        monkeypatch.setattr(
            provider.client,
            "get_crawl_request_results",
            lambda crawl_request_id, page, page_size, query_params: responses.pop(0),
        )
        results = list(provider._get_results("job-1"))
        assert len(results) == 1
        assert results[0]["source_url"] == "https://a"
    def test_scrape_url_uses_client_and_structure(self, monkeypatch):
        """scrape_url delegates to the client and runs the payload through _structure_data."""
        provider = WaterCrawlProvider(api_key="k")
        monkeypatch.setattr(
            provider.client, "scrape_url", lambda **kwargs: {"result": {"metadata": {}, "markdown": "m"}, "url": "u"}
        )
        result = provider.scrape_url("u")
        assert result["source_url"] == "u"
class TestWaterCrawlWebExtractor:
    """Tests for WaterCrawlWebExtractor mode dispatch (crawl / scrape / unknown)."""
    def test_extract_crawl_and_scrape_modes(self, monkeypatch):
        """'crawl' mode fetches crawl-result data; 'scrape' mode fetches live scrape data."""
        monkeypatch.setattr(
            "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data",
            lambda job_id, provider, url, tenant_id: {
                "markdown": "crawl",
                "source_url": url,
                "description": "d",
                "title": "t",
            },
        )
        monkeypatch.setattr(
            "core.rag.extractor.watercrawl.extractor.WebsiteService.get_scrape_url_data",
            lambda provider, url, tenant_id, only_main_content: {
                "markdown": "scrape",
                "source_url": url,
                "description": "d",
                "title": "t",
            },
        )
        crawl_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
        scrape_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape")
        assert crawl_extractor.extract()[0].page_content == "crawl"
        assert scrape_extractor.extract()[0].page_content == "scrape"
    def test_extract_crawl_returns_empty_when_service_returns_none(self, monkeypatch):
        """A missing crawl result yields an empty document list rather than an error."""
        monkeypatch.setattr(
            "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data",
            lambda job_id, provider, url, tenant_id: None,
        )
        extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
        assert extractor.extract() == []
    def test_extract_unknown_mode_returns_empty(self):
        """An unrecognized mode yields an empty document list."""
        extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="other")
        assert extractor.extract() == []

View File

@ -0,0 +1,33 @@
from contextlib import AbstractContextManager, nullcontext
from typing import Any
import pytest
class _FakeFlaskApp:
def app_context(self) -> AbstractContextManager[None]:
return nullcontext()
class _FakeExecutor:
def __init__(self, future: Any) -> None:
self._future = future
def __enter__(self) -> "_FakeExecutor":
return self
def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool:
return False
def submit(self, func: object, preview: object) -> Any:
return self._future
@pytest.fixture
def fake_flask_app() -> _FakeFlaskApp:
    """Provide a no-op Flask-app stand-in for tests that need an app context."""
    return _FakeFlaskApp()
@pytest.fixture
def fake_executor_cls() -> type[_FakeExecutor]:
    """Provide the fake executor class so tests can construct it with a canned future."""
    return _FakeExecutor

View File

@ -0,0 +1,629 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent
from dify_graph.model_runtime.entities.model_entities import ModelFeature
class TestParagraphIndexProcessor:
    @pytest.fixture
    def processor(self) -> ParagraphIndexProcessor:
        """Fresh ParagraphIndexProcessor instance under test."""
        return ParagraphIndexProcessor()
    @pytest.fixture
    def dataset(self) -> Mock:
        """Mock dataset defaulting to high-quality, multimodal indexing."""
        dataset = Mock()
        dataset.id = "dataset-1"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "high_quality"
        dataset.is_multimodal = True
        return dataset
    @pytest.fixture
    def dataset_document(self) -> Mock:
        """Mock dataset document with an id and creating user."""
        document = Mock()
        document.id = "doc-1"
        document.created_by = "user-1"
        return document
    @pytest.fixture
    def process_rule(self) -> dict:
        """Custom-mode process rule with explicit segmentation settings."""
        return {
            "mode": "custom",
            "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}},
        }
    def _rules(self) -> SimpleNamespace:
        """Build a Rule-like namespace matching the process_rule fixture's segmentation."""
        segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n")
        return SimpleNamespace(segmentation=segmentation)
    def _llm_result(self, content: str = "summary") -> LLMResult:
        """Build an LLMResult with the given assistant message content and empty usage."""
        return LLMResult(
            model="llm-model",
            message=AssistantPromptMessage(content=content),
            usage=LLMUsage.empty_usage(),
        )
    def test_extract_forwards_automatic_flag(self, processor: ParagraphIndexProcessor) -> None:
        """extract() forwards is_automatic=True to ExtractProcessor for hierarchical mode."""
        extract_setting = Mock()
        expected_docs = [Document(page_content="chunk", metadata={})]
        with patch(
            "core.rag.index_processor.processor.paragraph_index_processor.ExtractProcessor.extract"
        ) as mock_extract:
            mock_extract.return_value = expected_docs
            docs = processor.extract(extract_setting, process_rule_mode="hierarchical")
        assert docs == expected_docs
        mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True)
    def test_transform_validates_process_rule(self, processor: ParagraphIndexProcessor) -> None:
        """transform() rejects a missing process rule and a rule without 'rules'."""
        with pytest.raises(ValueError, match="No process rule found"):
            processor.transform([Document(page_content="text", metadata={})], process_rule=None)
        with pytest.raises(ValueError, match="No rules found in process rule"):
            processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})
    def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None:
        """transform() rejects rules whose segmentation section is missing."""
        rules_without_segmentation = SimpleNamespace(segmentation=None)
        with patch(
            "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate",
            return_value=rules_without_segmentation,
        ):
            with pytest.raises(ValueError, match="No segmentation found in rules"):
                processor.transform(
                    [Document(page_content="text", metadata={})],
                    process_rule={"mode": "custom", "rules": {"enabled": True}},
                )
    def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None:
        """Split chunks are cleaned, hashed, stripped of leading symbols, and given attachments;
        whitespace-only chunks are dropped."""
        source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
        splitter = Mock()
        # Second chunk is whitespace-only and should be discarded.
        splitter.split_documents.return_value = [
            Document(page_content=".first", metadata={}),
            Document(page_content=" ", metadata={}),
        ]
        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate",
                return_value=self._rules(),
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean",
                return_value=".first",
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols",
                side_effect=lambda text: text.lstrip("."),
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
            ),
        ):
            documents = processor.transform([source_document], process_rule=process_rule)
        assert len(documents) == 1
        assert documents[0].page_content == "first"
        assert documents[0].attachments is not None
        assert documents[0].metadata["doc_hash"] == "hash"
    def test_transform_automatic_mode_uses_default_rules(self, processor: ParagraphIndexProcessor) -> None:
        """Automatic mode validates the built-in default rules exactly once."""
        splitter = Mock()
        splitter.split_documents.return_value = [Document(page_content="text", metadata={})]
        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate",
                return_value=self._rules(),
            ) as mock_validate,
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean",
                side_effect=lambda text, _: text,
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols",
                side_effect=lambda text: text,
            ),
            patch.object(processor, "_get_content_files", return_value=[]),
        ):
            processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "automatic"})
        assert mock_validate.call_count == 1
    def test_load_creates_vector_and_multimodal_when_high_quality(
        self, processor: ParagraphIndexProcessor, dataset: Mock
    ) -> None:
        """High-quality indexing writes text and multimodal vectors and never touches keywords."""
        docs = [Document(page_content="chunk", metadata={})]
        multimodal_docs = [AttachmentDocument(page_content="image", metadata={})]
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls,
        ):
            processor.load(dataset, docs, multimodal_documents=multimodal_docs)
        vector = mock_vector_cls.return_value
        vector.create.assert_called_once_with(docs)
        vector.create_multimodal.assert_called_once_with(multimodal_docs)
        mock_keyword_cls.assert_not_called()
    def test_load_uses_keyword_add_texts_with_keywords_when_economy(
        self, processor: ParagraphIndexProcessor, dataset: Mock
    ) -> None:
        """Economy indexing passes the supplied keywords_list to Keyword.add_texts."""
        dataset.indexing_technique = "economy"
        docs = [Document(page_content="chunk", metadata={})]
        with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
            processor.load(dataset, docs, keywords_list=["k1", "k2"])
        mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs, keywords_list=["k1", "k2"])
    def test_load_uses_keyword_add_texts_without_keywords_when_economy(
        self, processor: ParagraphIndexProcessor, dataset: Mock
    ) -> None:
        """Economy indexing without keywords calls Keyword.add_texts with documents only."""
        dataset.indexing_technique = "economy"
        docs = [Document(page_content="chunk", metadata={})]
        with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
            processor.load(dataset, docs)
        mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs)
    def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
        """clean() with node ids deletes summaries for the matched segments and the vector entries."""
        segment_query = Mock()
        segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
        session = Mock()
        session.query.return_value = segment_query
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments"
            ) as mock_summary,
            patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls,
        ):
            vector = mock_vector_cls.return_value
            processor.clean(dataset, ["node-1"], delete_summaries=True)
        mock_summary.assert_called_once_with(dataset, ["seg-1"])
        vector.delete_by_ids.assert_called_once_with(["node-1"])
    def test_clean_economy_deletes_summaries_and_keywords(
        self, processor: ParagraphIndexProcessor, dataset: Mock
    ) -> None:
        """clean() without node ids on an economy dataset deletes all summaries and keywords."""
        dataset.indexing_technique = "economy"
        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments"
            ) as mock_summary,
            patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls,
        ):
            processor.clean(dataset, None, delete_summaries=True)
        mock_summary.assert_called_once_with(dataset, None)
        mock_keyword_cls.return_value.delete.assert_called_once()
    def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
        """clean() with node ids on an economy dataset deletes just those keyword entries."""
        dataset.indexing_technique = "economy"
        with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
            processor.clean(dataset, ["node-2"], with_keywords=True)
        mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"])
    def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
        """retrieve() keeps only results scoring above the score threshold."""
        accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9)
        rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1)
        with patch(
            "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve"
        ) as mock_retrieve:
            mock_retrieve.return_value = [accepted, rejected]
            docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
        assert len(docs) == 1
        assert docs[0].metadata["score"] == 0.9
    def test_index_list_chunks_high_quality(
        self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """Indexing a plain chunk list on a high-quality dataset stores docs and vectors (text + multimodal)."""
        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})]
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"
            ) as mock_store_cls,
            patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls,
        ):
            processor.index(dataset, dataset_document, ["chunk-1", "chunk-2"])
        mock_store_cls.return_value.add_documents.assert_called_once()
        mock_vector_cls.return_value.create.assert_called_once()
        mock_vector_cls.return_value.create_multimodal.assert_called_once()
    def test_index_list_chunks_economy(
        self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """Indexing a plain chunk list on an economy dataset writes keywords, not vectors."""
        dataset.indexing_technique = "economy"
        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch.object(processor, "_get_content_files", return_value=[]),
            patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"),
            patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls,
        ):
            processor.index(dataset, dataset_document, ["chunk-3"])
        mock_keyword_cls.return_value.add_texts.assert_called_once()
    def test_index_multimodal_structure_handles_files_and_account_lookup(
        self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """Structured multimodal input resolves the creating account and fetches files only for chunks that have them."""
        chunk_with_files = SimpleNamespace(
            content="content-1",
            files=[SimpleNamespace(id="file-1", filename="image.png")],
        )
        chunk_without_files = SimpleNamespace(content="content-2", files=None)
        structure = SimpleNamespace(general_chunks=[chunk_with_files, chunk_without_files])
        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate",
                return_value=structure,
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user",
                return_value=SimpleNamespace(id="user-1"),
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})]
            ) as mock_files,
            patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"),
            patch("core.rag.index_processor.processor.paragraph_index_processor.Vector"),
        ):
            processor.index(dataset, dataset_document, {"general_chunks": []})
        # Only the chunk that actually carries files triggers a file lookup.
        assert mock_files.call_count == 1
    def test_index_multimodal_structure_requires_valid_account(
        self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """Structured multimodal indexing fails when the creating account cannot be loaded."""
        structure = SimpleNamespace(general_chunks=[SimpleNamespace(content="content", files=None)])
        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate",
                return_value=structure,
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user",
                return_value=None,
            ),
        ):
            with pytest.raises(ValueError, match="Invalid account"):
                processor.index(dataset, dataset_document, {"general_chunks": []})
    def test_format_preview_validates_chunk_shape(self, processor: ParagraphIndexProcessor) -> None:
        """format_preview() reports structure and counts for a list and rejects non-list input."""
        preview = processor.format_preview(["chunk-1", "chunk-2"])
        assert preview["chunk_structure"] == "text_model"
        assert preview["total_segments"] == 2
        with pytest.raises(ValueError, match="Chunks is not a list"):
            processor.format_preview({"not": "a-list"})
    def test_generate_summary_preview_success_and_failure(self, processor: ParagraphIndexProcessor) -> None:
        """Summaries are attached to every preview item; any generation error surfaces as ValueError."""
        preview_items = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")]
        with patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())):
            result = processor.generate_summary_preview(
                "tenant-1", preview_items, {"enable": True}, doc_language="English"
            )
        assert all(item.summary == "summary" for item in result)
        with patch.object(processor, "generate_summary", side_effect=RuntimeError("summary failed")):
            with pytest.raises(ValueError, match="Failed to generate summaries"):
                processor.generate_summary_preview("tenant-1", [PreviewDetail(content="chunk-1")], {"enable": True})
    def test_generate_summary_preview_fallback_without_flask_context(self, processor: ParagraphIndexProcessor) -> None:
        """Preview generation still succeeds when no Flask application context is available."""
        preview_items = [PreviewDetail(content="chunk-1")]
        # _get_current_object raising simulates running outside any Flask app context.
        fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app")))
        with (
            patch("flask.current_app", fake_current_app),
            patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())),
        ):
            result = processor.generate_summary_preview("tenant-1", preview_items, {"enable": True})
            assert result[0].summary == "summary"
    def test_generate_summary_preview_timeout(
        self, processor: ParagraphIndexProcessor, fake_executor_cls: type
    ) -> None:
        """A future that never completes triggers a timeout ValueError and is cancelled."""
        preview_items = [PreviewDetail(content="chunk-1")]
        future = Mock()
        executor = fake_executor_cls(future)
        with (
            patch("concurrent.futures.ThreadPoolExecutor", return_value=executor),
            # First wait() call reports the future still pending; the second reports nothing done.
            patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]),
        ):
            with pytest.raises(ValueError, match="timeout"):
                processor.generate_summary_preview("tenant-1", preview_items, {"enable": True})
        future.cancel.assert_called_once()
def test_generate_summary_validates_input(self) -> None:
with pytest.raises(ValueError, match="must be enabled"):
ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": False})
with pytest.raises(ValueError, match="model_name and model_provider_name"):
ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True})
    def test_generate_summary_text_only_flow(self) -> None:
        """Text-only path returns the LLM text; quota-deduction failures are logged, not raised."""
        model_instance = Mock()
        model_instance.credentials = {"k": "v"}
        # No VISION feature in the model schema: only the text prompt path is exercised.
        model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[])
        model_instance.invoke_llm.return_value = self._llm_result("text summary")
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
            ),
            patch(
                # Quota deduction blowing up must not break summary generation.
                "core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota",
                side_effect=RuntimeError("quota"),
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            summary, usage = ParagraphIndexProcessor.generate_summary(
                "tenant-1",
                "text content",
                {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
                document_language="English",
            )
            assert summary == "text summary"
            assert isinstance(usage, LLMUsage)
            mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota")
    def test_generate_summary_handles_vision_and_image_conversion(self) -> None:
        """Vision-capable model + segment_id uses attachment images and skips text extraction."""
        model_instance = Mock()
        model_instance.credentials = {"k": "v"}
        model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
            features=[ModelFeature.VISION]
        )
        model_instance.invoke_llm.return_value = self._llm_result("vision summary")
        image_file = SimpleNamespace()
        image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png")
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
            ),
            # Attachment lookup yields one image, so markdown-text extraction is not needed.
            patch.object(
                ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[image_file]
            ),
            patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[]) as mock_extract_text,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content",
                return_value=image_content,
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"),
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            summary, _ = ParagraphIndexProcessor.generate_summary(
                "tenant-1",
                "text content",
                {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
                segment_id="seg-1",
            )
            assert summary == "vision summary"
            mock_extract_text.assert_not_called()
    def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None:
        """Bad prompt template and image-conversion failures are tolerated; a non-LLMResult return raises."""
        model_instance = Mock()
        model_instance.credentials = {"k": "v"}
        model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
            features=[ModelFeature.VISION]
        )
        # Return a plain object() so the result-type validation path is hit.
        model_instance.invoke_llm.return_value = object()
        image_file = SimpleNamespace()
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
            ),
            patch(
                # Template with an unknown placeholder exercises the prompt-format fallback.
                "core.rag.index_processor.processor.paragraph_index_processor.DEFAULT_GENERATOR_SUMMARY_PROMPT",
                "Prompt {missing}",
            ),
            patch.object(ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[]),
            patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[image_file]),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content",
                side_effect=RuntimeError("bad image"),
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            with pytest.raises(ValueError, match="Expected LLMResult"):
                ParagraphIndexProcessor.generate_summary(
                    "tenant-1",
                    "text content",
                    {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
                )
            mock_logger.warning.assert_called_with(
                "Failed to convert image file to prompt message content: %s", "bad image"
            )
    def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None:
        """Only uploads with an image mime type are converted; non-image uploads are skipped silently."""
        # Three markdown link styles: image-preview, file-preview and tool-file URL.
        text = (
            "![img](/files/11111111-1111-1111-1111-111111111111/image-preview) "
            "![img2](/files/22222222-2222-2222-2222-222222222222/file-preview) "
            "![tool](/files/tools/33333333-3333-3333-3333-333333333333.png)"
        )
        image_upload = SimpleNamespace(
            id="11111111-1111-1111-1111-111111111111",
            tenant_id="tenant-1",
            name="image.png",
            mime_type="image/png",
            extension="png",
            source_url="",
            size=1,
            key="key",
        )
        # text/plain mime type: must be filtered out without logging a warning.
        non_image_upload = SimpleNamespace(
            id="22222222-2222-2222-2222-222222222222",
            tenant_id="tenant-1",
            name="file.txt",
            mime_type="text/plain",
            extension="txt",
            source_url="",
            size=1,
            key="key",
        )
        query = Mock()
        query.where.return_value.all.return_value = [image_upload, non_image_upload]
        session = Mock()
        session.query.return_value = query
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
                return_value=SimpleNamespace(id="file-1"),
            ) as mock_builder,
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text)
            assert len(files) == 1
            assert mock_builder.call_count == 1
            mock_logger.warning.assert_not_called()
def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None:
assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here") == []
    def test_extract_images_from_text_logs_when_build_fails(self) -> None:
        """A build_from_mapping failure is logged as a warning and the file is skipped."""
        text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)"
        image_upload = SimpleNamespace(
            id="11111111-1111-1111-1111-111111111111",
            tenant_id="tenant-1",
            name="image.png",
            mime_type="image/png",
            extension="png",
            source_url="",
            size=1,
            key="key",
        )
        query = Mock()
        query.where.return_value.all.return_value = [image_upload]
        session = Mock()
        session.query.return_value = query
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
                side_effect=RuntimeError("build failed"),
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text)
            assert files == []
            mock_logger.warning.assert_called_once()
    def test_extract_images_from_segment_attachments(self) -> None:
        """Only well-formed image attachments are kept; malformed rows log one warning."""
        image_upload = SimpleNamespace(
            id="file-1",
            name="image",
            extension="png",
            mime_type="image/png",
            source_url="",
            size=1,
            key="k1",
        )
        # extension=None should trigger the warning path rather than producing a file.
        bad_upload = SimpleNamespace(
            id="file-2",
            name="broken",
            extension=None,
            mime_type="image/png",
            source_url="",
            size=1,
            key="k2",
        )
        # Non-image mime type is filtered out silently.
        non_image_upload = SimpleNamespace(
            id="file-3",
            name="text",
            extension="txt",
            mime_type="text/plain",
            source_url="",
            size=1,
            key="k3",
        )
        execute_result = Mock()
        execute_result.all.return_value = [(None, image_upload), (None, bad_upload), (None, non_image_upload)]
        session = Mock()
        session.execute.return_value = execute_result
        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1")
            assert len(files) == 1
            mock_logger.warning.assert_called_once()
def test_extract_images_from_segment_attachments_empty(self) -> None:
execute_result = Mock()
execute_result.all.return_value = []
session = Mock()
session.execute.return_value = execute_result
with patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session):
empty_files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1")
assert empty_files == []

View File

@ -0,0 +1,523 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from services.entities.knowledge_entities.knowledge_entities import ParentMode
class TestParentChildIndexProcessor:
    """Unit tests for ParentChildIndexProcessor.

    Covers extract/transform/load/clean/retrieve/index/format_preview and the
    summary-preview helpers, using mocks for the database session, vector store
    and external services.
    """

    @pytest.fixture
    def processor(self) -> ParentChildIndexProcessor:
        """Provide a fresh processor instance per test."""
        return ParentChildIndexProcessor()

    @pytest.fixture
    def dataset(self) -> Mock:
        """Mock dataset configured for high-quality, multimodal indexing."""
        dataset = Mock()
        dataset.id = "dataset-1"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "high_quality"
        dataset.is_multimodal = True
        return dataset

    @pytest.fixture
    def dataset_document(self) -> Mock:
        """Mock dataset document with no pre-existing process rule."""
        document = Mock()
        document.id = "doc-1"
        document.created_by = "user-1"
        document.dataset_process_rule_id = None
        return document

    def _segmentation(self) -> SimpleNamespace:
        """Minimal segmentation rule stub shared by the rule helpers below."""
        return SimpleNamespace(max_tokens=200, chunk_overlap=10, separator="\n")

    def _paragraph_rules(self) -> SimpleNamespace:
        """Rule stub for PARAGRAPH parent mode with parent and child segmentation."""
        return SimpleNamespace(
            parent_mode=ParentMode.PARAGRAPH,
            segmentation=self._segmentation(),
            subchunk_segmentation=self._segmentation(),
        )

    def _full_doc_rules(self) -> SimpleNamespace:
        """Rule stub for FULL_DOC parent mode (no parent segmentation)."""
        return SimpleNamespace(
            parent_mode=ParentMode.FULL_DOC, segmentation=None, subchunk_segmentation=self._segmentation()
        )

    def test_extract_forwards_automatic_flag(self, processor: ParentChildIndexProcessor) -> None:
        """extract() forwards is_automatic=True to ExtractProcessor for hierarchical mode."""
        extract_setting = Mock()
        expected = [Document(page_content="chunk", metadata={})]
        with patch(
            "core.rag.index_processor.processor.parent_child_index_processor.ExtractProcessor.extract"
        ) as mock_extract:
            mock_extract.return_value = expected
            documents = processor.extract(extract_setting, process_rule_mode="hierarchical")
            assert documents == expected
            mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True)

    def test_transform_validates_process_rule(self, processor: ParentChildIndexProcessor) -> None:
        """transform() rejects a missing process rule and a rule without a 'rules' payload."""
        with pytest.raises(ValueError, match="No process rule found"):
            processor.transform([Document(page_content="text", metadata={})], process_rule=None)
        with pytest.raises(ValueError, match="No rules found in process rule"):
            processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})

    def test_transform_paragraph_requires_segmentation(self, processor: ParentChildIndexProcessor) -> None:
        """PARAGRAPH mode without a segmentation rule raises ValueError."""
        rules = SimpleNamespace(parent_mode=ParentMode.PARAGRAPH, segmentation=None)
        with patch(
            "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", return_value=rules
        ):
            with pytest.raises(ValueError, match="No segmentation found in rules"):
                processor.transform(
                    [Document(page_content="text", metadata={})],
                    process_rule={"mode": "custom", "rules": {"enabled": True}},
                )

    def test_transform_paragraph_builds_parent_and_child_docs(self, processor: ParentChildIndexProcessor) -> None:
        """PARAGRAPH transform strips leading symbols, drops blank chunks and attaches children/files."""
        splitter = Mock()
        # One real chunk plus a whitespace-only chunk that must be discarded.
        splitter.split_documents.return_value = [
            Document(page_content=".parent", metadata={}),
            Document(page_content=" ", metadata={}),
        ]
        parent_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
        child_docs = [ChildDocument(page_content="child-1", metadata={"dataset_id": "dataset-1"})]
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate",
                return_value=self._paragraph_rules(),
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean",
                return_value=".parent",
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
            ),
            patch.object(processor, "_split_child_nodes", return_value=child_docs),
        ):
            result = processor.transform(
                [parent_document],
                process_rule={"mode": "custom", "rules": {"enabled": True}},
                preview=False,
            )
            assert len(result) == 1
            # Leading "." is stripped from the cleaned parent content.
            assert result[0].page_content == "parent"
            assert result[0].children == child_docs
            assert result[0].attachments is not None

    def test_transform_preview_returns_after_ten_parent_chunks(self, processor: ParentChildIndexProcessor) -> None:
        """Preview mode stops after ten parent chunks even with more source documents."""
        splitter = Mock()
        splitter.split_documents.return_value = [Document(page_content=f"chunk-{i}", metadata={}) for i in range(10)]
        documents = [
            Document(page_content="doc-1", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}),
            Document(page_content="doc-2", metadata={"dataset_id": "dataset-1", "document_id": "doc-2"}),
        ]
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate",
                return_value=self._paragraph_rules(),
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean",
                side_effect=lambda text, _: text,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch.object(processor, "_get_content_files", return_value=[]),
            patch.object(processor, "_split_child_nodes", return_value=[]),
        ):
            result = processor.transform(
                documents,
                process_rule={"mode": "custom", "rules": {"enabled": True}},
                preview=True,
            )
            # Both source docs would give 20 chunks; preview caps at 10.
            assert len(result) == 10

    def test_transform_full_doc_mode_trims_children_for_preview(self, processor: ParentChildIndexProcessor) -> None:
        """FULL_DOC preview merges docs into one parent and trims children to the configured limit."""
        docs = [
            Document(page_content="first", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}),
            Document(page_content="second", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}),
        ]
        child_docs = [ChildDocument(page_content=f"child-{i}", metadata={}) for i in range(5)]
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate",
                return_value=self._full_doc_rules(),
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
            ),
            patch.object(processor, "_split_child_nodes", return_value=child_docs),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                # Limit preview children to two to verify the trim.
                "core.rag.index_processor.processor.parent_child_index_processor.dify_config.CHILD_CHUNKS_PREVIEW_NUMBER",
                2,
            ),
        ):
            result = processor.transform(
                docs,
                process_rule={"mode": "hierarchical", "rules": {"enabled": True}},
                preview=True,
            )
            assert len(result) == 1
            assert len(result[0].children or []) == 2
            assert result[0].attachments is not None

    def test_load_creates_vectors_for_child_docs(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """load() flattens child docs into the vector store and indexes multimodal docs."""
        parent_doc = Document(
            page_content="parent",
            metadata={},
            children=[
                ChildDocument(page_content="child-1", metadata={}),
                ChildDocument(page_content="child-2", metadata={}),
            ],
        )
        multimodal_docs = [AttachmentDocument(page_content="image", metadata={})]
        with patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls:
            vector = mock_vector_cls.return_value
            processor.load(dataset, [parent_doc], multimodal_documents=multimodal_docs)
            assert vector.create.call_count == 1
            formatted_docs = vector.create.call_args[0][0]
            # One vector document per child chunk.
            assert len(formatted_docs) == 2
            assert all(isinstance(doc, Document) for doc in formatted_docs)
            vector.create_multimodal.assert_called_once_with(multimodal_docs)

    def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """clean() with precomputed child ids deletes exactly those vectors plus DB rows."""
        where_query = Mock()
        where_query.delete.return_value = 2
        session = Mock()
        session.query.return_value.where.return_value = where_query
        with (
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            vector = mock_vector_cls.return_value
            processor.clean(
                dataset,
                ["node-1"],
                delete_child_chunks=True,
                precomputed_child_node_ids=["child-1", "child-2"],
            )
            vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
            where_query.delete.assert_called_once_with(synchronize_session=False)
            session.commit.assert_called_once()

    def test_clean_queries_child_ids_when_not_precomputed(
        self, processor: ParentChildIndexProcessor, dataset: Mock
    ) -> None:
        """clean() queries child ids from the DB and skips None rows."""
        child_query = Mock()
        # Row with None id must be filtered out before the vector delete.
        child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)]
        session = Mock()
        session.query.return_value = child_query
        with (
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            vector = mock_vector_cls.return_value
            processor.clean(dataset, ["node-1"], delete_child_chunks=False)
            vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])

    def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """clean() without node ids performs a dataset-wide vector delete and DB cleanup."""
        where_query = Mock()
        where_query.delete.return_value = 3
        session = Mock()
        session.query.return_value.where.return_value = where_query
        with (
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            vector = mock_vector_cls.return_value
            processor.clean(dataset, None, delete_child_chunks=True)
            vector.delete.assert_called_once()
            where_query.delete.assert_called_once_with(synchronize_session=False)
            session.commit.assert_called_once()

    def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """clean() with delete_summaries=True removes summaries for the matched segment ids."""
        segment_query = Mock()
        segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
        session = Mock()
        session.query.return_value = segment_query
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
        session_ctx.__exit__.return_value = False
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.session_factory.create_session",
                return_value=session_ctx,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments"
            ) as mock_summary,
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"),
        ):
            processor.clean(dataset, ["node-1"], delete_summaries=True, precomputed_child_node_ids=[])
            mock_summary.assert_called_once_with(dataset, ["seg-1"])

    def test_clean_deletes_all_summaries_when_node_ids_missing(
        self, processor: ParentChildIndexProcessor, dataset: Mock
    ) -> None:
        """clean() with no node ids delegates a dataset-wide summary delete (segment ids None)."""
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments"
            ) as mock_summary,
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"),
        ):
            processor.clean(dataset, None, delete_summaries=True)
            mock_summary.assert_called_once_with(dataset, None)

    def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """retrieve() drops results scoring below the threshold and records scores in metadata."""
        ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8)
        low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2)
        with patch(
            "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve"
        ) as mock_retrieve:
            mock_retrieve.return_value = [ok_result, low_result]
            docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {})
            assert len(docs) == 1
            assert docs[0].page_content == "keep"
            assert docs[0].metadata["score"] == 0.8

    def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None:
        """_split_child_nodes raises when no subchunk segmentation rule is provided."""
        rules = SimpleNamespace(subchunk_segmentation=None)
        with pytest.raises(ValueError, match="No subchunk segmentation found"):
            processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None)

    def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None:
        """_split_child_nodes strips leading symbols, drops blank chunks and hashes content."""
        rules = SimpleNamespace(subchunk_segmentation=self._segmentation())
        splitter = Mock()
        splitter.split_documents.return_value = [
            Document(page_content=".child-1", metadata={}),
            Document(page_content=" ", metadata={}),
        ]
        with (
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
        ):
            child_docs = processor._split_child_nodes(
                Document(page_content="parent", metadata={}), rules, "custom", None
            )
            assert len(child_docs) == 1
            assert child_docs[0].page_content == "child-1"
            assert child_docs[0].metadata["doc_hash"] == "hash"

    def test_index_creates_process_rule_segments_and_vectors(
        self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """index() persists a process rule, stores segments and writes text + multimodal vectors."""
        parent_childs = SimpleNamespace(
            parent_mode=ParentMode.PARAGRAPH,
            parent_child_chunks=[
                SimpleNamespace(
                    parent_content="parent text",
                    child_contents=["child-1", "child-2"],
                    files=[SimpleNamespace(id="file-1", filename="image.png")],
                )
            ],
        )
        dataset_rule = SimpleNamespace(id="rule-1")
        session = Mock()
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
                return_value=parent_childs,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule",
                return_value=dataset_rule,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                side_effect=lambda text: f"hash-{text}",
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"
            ) as mock_store_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            processor.index(dataset, dataset_document, {"parent_child_chunks": []})
            assert dataset_document.dataset_process_rule_id == "rule-1"
            session.add.assert_called_once_with(dataset_rule)
            session.flush.assert_called_once()
            session.commit.assert_called_once()
            mock_store_cls.return_value.add_documents.assert_called_once()
            assert mock_vector_cls.return_value.create.call_count == 1
            mock_vector_cls.return_value.create_multimodal.assert_called_once()

    def test_index_uses_content_files_when_files_missing(
        self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """index() falls back to _get_content_files when a chunk carries no explicit files."""
        parent_childs = SimpleNamespace(
            parent_mode=ParentMode.PARAGRAPH,
            parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)],
        )
        dataset_rule = SimpleNamespace(id="rule-1")
        session = Mock()
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
                return_value=parent_childs,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule",
                return_value=dataset_rule,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user",
                return_value=SimpleNamespace(id="user-1"),
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
            ) as mock_files,
            patch("core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"),
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"),
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            processor.index(dataset, dataset_document, {"parent_child_chunks": []})
            mock_files.assert_called_once()

    def test_index_raises_when_account_missing(
        self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """index() raises ValueError when the document creator's account cannot be loaded."""
        parent_childs = SimpleNamespace(
            parent_mode=ParentMode.PARAGRAPH,
            parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)],
        )
        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
                return_value=parent_childs,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user",
                return_value=None,
            ),
        ):
            with pytest.raises(ValueError, match="Invalid account"):
                processor.index(dataset, dataset_document, {"parent_child_chunks": []})

    def test_format_preview_returns_parent_child_structure(self, processor: ParentChildIndexProcessor) -> None:
        """format_preview reports the hierarchical structure, parent mode and segment count."""
        parent_childs = SimpleNamespace(
            parent_mode=ParentMode.PARAGRAPH,
            parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child-1", "child-2"])],
        )
        with patch(
            "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
            return_value=parent_childs,
        ):
            preview = processor.format_preview({"parent_child_chunks": []})
            assert preview["chunk_structure"] == "hierarchical_model"
            assert preview["parent_mode"] == ParentMode.PARAGRAPH
            assert preview["total_segments"] == 1

    def test_generate_summary_preview_sets_summaries(self, processor: ParentChildIndexProcessor) -> None:
        """Each preview item receives the summary produced by the paragraph processor helper."""
        preview_texts = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")]
        with patch(
            "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary",
            return_value=("summary", None),
        ):
            result = processor.generate_summary_preview(
                "tenant-1", preview_texts, {"enable": True}, doc_language="English"
            )
            assert all(item.summary == "summary" for item in result)

    def test_generate_summary_preview_raises_when_worker_fails(self, processor: ParentChildIndexProcessor) -> None:
        """A failing summary worker surfaces as a ValueError from the preview call."""
        preview_texts = [PreviewDetail(content="chunk-1")]
        with patch(
            "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary",
            side_effect=RuntimeError("summary failed"),
        ):
            with pytest.raises(ValueError, match="Failed to generate summaries"):
                processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True})

    def test_generate_summary_preview_falls_back_without_flask_context(
        self, processor: ParentChildIndexProcessor
    ) -> None:
        """Summary preview still works when no Flask application context is available."""
        preview_texts = [PreviewDetail(content="chunk-1")]
        # _get_current_object raising simulates running outside any Flask app context.
        fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app")))
        with (
            patch("flask.current_app", fake_current_app),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary",
                return_value=("summary", None),
            ),
        ):
            result = processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True})
            assert result[0].summary == "summary"

    def test_generate_summary_preview_handles_timeout(
        self, processor: ParentChildIndexProcessor, fake_executor_cls: type
    ) -> None:
        """A future that never completes triggers a timeout ValueError and is cancelled."""
        preview_texts = [PreviewDetail(content="chunk-1")]
        future = Mock()
        executor = fake_executor_cls(future)
        with (
            patch("concurrent.futures.ThreadPoolExecutor", return_value=executor),
            # First wait() call reports the future still pending; the second reports nothing done.
            patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]),
        ):
            with pytest.raises(ValueError, match="timeout"):
                processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True})
        future.cancel.assert_called_once()

View File

@ -0,0 +1,382 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from werkzeug.datastructures import FileStorage
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
class _ImmediateThread:
def __init__(self, target, args=(), kwargs=None):
self._target = target
self._args = args
self._kwargs = kwargs or {}
def start(self) -> None:
self._target(*self._args, **self._kwargs)
def join(self) -> None:
return None
class TestQAIndexProcessor:
    @pytest.fixture
    def processor(self) -> QAIndexProcessor:
        """Provide a fresh QAIndexProcessor instance per test."""
        return QAIndexProcessor()
    @pytest.fixture
    def dataset(self) -> Mock:
        """Mock dataset configured for high-quality, multimodal indexing."""
        dataset = Mock()
        dataset.id = "dataset-1"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "high_quality"
        dataset.is_multimodal = True
        return dataset
    @pytest.fixture
    def dataset_document(self) -> Mock:
        """Mock dataset document stub with stable id and creator."""
        document = Mock()
        document.id = "doc-1"
        document.created_by = "user-1"
        return document
    @pytest.fixture
    def process_rule(self) -> dict:
        """Custom process rule payload with a basic segmentation configuration."""
        return {
            "mode": "custom",
            "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}},
        }
    def _rules(self) -> SimpleNamespace:
        """Parsed-rule stub matching the process_rule fixture's segmentation values."""
        segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n")
        return SimpleNamespace(segmentation=segmentation)
def test_extract_forwards_automatic_flag(self, processor: QAIndexProcessor) -> None:
extract_setting = Mock()
expected_docs = [Document(page_content="chunk", metadata={})]
with patch("core.rag.index_processor.processor.qa_index_processor.ExtractProcessor.extract") as mock_extract:
mock_extract.return_value = expected_docs
docs = processor.extract(extract_setting, process_rule_mode="automatic")
assert docs == expected_docs
mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True)
def test_transform_rejects_none_process_rule(self, processor: QAIndexProcessor) -> None:
with pytest.raises(ValueError, match="No process rule found"):
processor.transform([Document(page_content="text", metadata={})], process_rule=None)
def test_transform_rejects_missing_rules_key(self, processor: QAIndexProcessor) -> None:
with pytest.raises(ValueError, match="No rules found in process rule"):
processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})
    def test_transform_preview_calls_formatter_once(
        self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app
    ) -> None:
        """Preview transform formats exactly one QA document and returns its Q/A content."""
        document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
        split_node = Document(page_content=".question", metadata={})
        splitter = Mock()
        splitter.split_documents.return_value = [split_node]

        def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language):
            # Stand-in for the QA formatter: appends one fixed Q/A pair.
            all_qa_documents.append(Document(page_content="Q1", metadata={"answer": "A1"}))

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules()
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", return_value="clean text"
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols",
                side_effect=lambda text: text.lstrip("."),
            ),
            patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format,
            patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app,
        ):
            mock_current_app._get_current_object.return_value = fake_flask_app
            result = processor.transform(
                [document],
                process_rule=process_rule,
                preview=True,
                tenant_id="tenant-1",
                doc_language="English",
            )
            assert len(result) == 1
            assert result[0].metadata["answer"] == "A1"
            mock_format.assert_called_once()
def test_transform_non_preview_uses_thread_batches(
    self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app
) -> None:
    """Non-preview transform fans documents out to worker threads; every document is formatted."""
    documents = [
        Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}),
        Document(page_content="doc-2", metadata={"document_id": "doc-2", "dataset_id": "dataset-1"}),
    ]
    split_node = Document(page_content="question", metadata={})
    splitter = Mock()
    splitter.split_documents.return_value = [split_node]

    # Formatter stub: emit one QA document derived from the node content.
    def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language):
        all_qa_documents.append(Document(page_content=f"Q-{document_node.page_content}", metadata={"answer": "A"}))

    with (
        patch(
            "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules()
        ),
        patch.object(processor, "_get_splitter", return_value=splitter),
        patch(
            "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean",
            side_effect=lambda text, _: text,
        ),
        patch(
            "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
        ),
        patch(
            "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols",
            side_effect=lambda text: text,
        ),
        patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format,
        patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app,
        # _ImmediateThread runs the target synchronously so the test stays deterministic.
        patch(
            "core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread
        ),
    ):
        mock_current_app._get_current_object.return_value = fake_flask_app
        result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1")
        assert len(result) == 2
        assert mock_format.call_count == 2
def test_format_by_template_validates_file_type(self, processor: QAIndexProcessor) -> None:
    """Non-CSV uploads are rejected before any parsing happens."""
    not_csv_file = Mock(spec=FileStorage)
    not_csv_file.filename = "qa.txt"
    with pytest.raises(ValueError, match="Only CSV files"):
        processor.format_by_template(not_csv_file)
def test_format_by_template_parses_csv_rows(self, processor: QAIndexProcessor) -> None:
    """Each CSV row becomes a Document: column 0 is the question, column 1 the answer."""
    csv_file = Mock(spec=FileStorage)
    csv_file.filename = "qa.csv"
    dataframe = pd.DataFrame([["Q1", "A1"], ["Q2", "A2"]])
    with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=dataframe):
        docs = processor.format_by_template(csv_file)
        assert [doc.page_content for doc in docs] == ["Q1", "Q2"]
        assert [doc.metadata["answer"] for doc in docs] == ["A1", "A2"]
def test_format_by_template_raises_on_empty_csv(self, processor: QAIndexProcessor) -> None:
    """An empty dataframe is rejected with a ValueError mentioning 'empty'."""
    csv_file = Mock(spec=FileStorage)
    csv_file.filename = "qa.csv"
    with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=pd.DataFrame()):
        with pytest.raises(ValueError, match="empty"):
            processor.format_by_template(csv_file)
def test_format_by_template_raises_on_invalid_csv(self, processor: QAIndexProcessor) -> None:
    """Parser failures surface as a ValueError carrying the original message."""
    csv_file = Mock(spec=FileStorage)
    csv_file.filename = "qa.csv"
    with patch(
        "core.rag.index_processor.processor.qa_index_processor.pd.read_csv", side_effect=Exception("bad csv")
    ):
        with pytest.raises(ValueError, match="bad csv"):
            processor.format_by_template(csv_file)
def test_load_creates_vectors_for_high_quality_dataset(self, processor: QAIndexProcessor, dataset: Mock) -> None:
    """load() creates both text and multimodal vectors when the dataset is high_quality."""
    docs = [Document(page_content="Q1", metadata={"answer": "A1"})]
    multimodal_docs = [AttachmentDocument(page_content="image", metadata={})]
    with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls:
        vector = mock_vector_cls.return_value
        processor.load(dataset, docs, multimodal_documents=multimodal_docs)
        vector.create.assert_called_once_with(docs)
        vector.create_multimodal.assert_called_once_with(multimodal_docs)
def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None:
    """load() never touches the vector store for economy-mode datasets."""
    dataset.indexing_technique = "economy"
    docs = [Document(page_content="Q1", metadata={"answer": "A1"})]
    with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls:
        processor.load(dataset, docs)
        mock_vector_cls.assert_not_called()
def test_clean_handles_summary_deletion_and_vector_cleanup(
    self, processor: QAIndexProcessor, dataset: Mock
) -> None:
    """clean() with node ids removes summaries for the matched segments and vectors by id."""
    mock_segment = SimpleNamespace(id="seg-1")
    mock_query = Mock()
    mock_query.filter.return_value.all.return_value = [mock_segment]
    mock_session = Mock()
    mock_session.query.return_value = mock_query
    # Hand-rolled context manager so session_factory.create_session works under `with`.
    session_context = MagicMock()
    session_context.__enter__.return_value = mock_session
    session_context.__exit__.return_value = False
    with (
        patch(
            "core.rag.index_processor.processor.qa_index_processor.session_factory.create_session",
            return_value=session_context,
        ),
        patch(
            "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments"
        ) as mock_summary,
        patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls,
    ):
        vector = mock_vector_cls.return_value
        processor.clean(dataset, ["node-1"], delete_summaries=True)
        mock_summary.assert_called_once_with(dataset, ["seg-1"])
        vector.delete_by_ids.assert_called_once_with(["node-1"])
def test_clean_handles_dataset_wide_cleanup(self, processor: QAIndexProcessor, dataset: Mock) -> None:
    """clean() with node_ids=None removes all summaries and drops the whole vector index."""
    with (
        patch(
            "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments"
        ) as mock_summary,
        patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls,
    ):
        vector = mock_vector_cls.return_value
        processor.clean(dataset, None, delete_summaries=True)
        mock_summary.assert_called_once_with(dataset, None)
        vector.delete.assert_called_once()
def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None:
    """retrieve() drops results below the score threshold and copies the score into metadata."""
    result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9)
    result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1)
    with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve:
        mock_retrieve.return_value = [result_ok, result_low]
        docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
        assert len(docs) == 1
        assert docs[0].page_content == "accepted"
        assert docs[0].metadata["score"] == 0.9
def test_index_adds_documents_and_vectors_for_high_quality(
    self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
    """index() stores segments via the doc store and builds vectors for high-quality datasets."""
    qa_chunks = SimpleNamespace(
        qa_chunks=[
            SimpleNamespace(question="Q1", answer="A1"),
            SimpleNamespace(question="Q2", answer="A2"),
        ]
    )
    with (
        patch(
            "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate",
            return_value=qa_chunks,
        ),
        patch(
            "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
        ),
        patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore") as mock_store_cls,
        patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls,
    ):
        processor.index(dataset, dataset_document, {"qa_chunks": []})
        mock_store_cls.return_value.add_documents.assert_called_once()
        mock_vector_cls.return_value.create.assert_called_once()
def test_index_requires_high_quality(
    self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
    """index() rejects datasets whose indexing technique is not high_quality."""
    dataset.indexing_technique = "economy"
    qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")])
    with (
        patch(
            "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate",
            return_value=qa_chunks,
        ),
        patch(
            "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
        ),
        patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore"),
    ):
        with pytest.raises(ValueError, match="must be high quality"):
            processor.index(dataset, dataset_document, {"qa_chunks": []})
def test_format_preview_returns_qa_preview(self, processor: QAIndexProcessor) -> None:
    """format_preview() reports the qa_model structure, segment count, and QA pairs."""
    qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")])
    with patch(
        "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate",
        return_value=qa_chunks,
    ):
        preview = processor.format_preview({"qa_chunks": []})
        assert preview["chunk_structure"] == "qa_model"
        assert preview["total_segments"] == 1
        assert preview["qa_preview"] == [{"question": "Q1", "answer": "A1"}]
def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None:
    """generate_summary_preview() is a pass-through (same object returned) for the QA processor."""
    preview_items = [PreviewDetail(content="Q1")]
    assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items
def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None:
    """Whitespace-only content produces no QA documents (no LLM call needed)."""
    all_qa_documents: list[Document] = []
    blank_document = Document(page_content=" ", metadata={})
    processor._format_qa_document(fake_flask_app, "tenant-1", blank_document, all_qa_documents, "English")
    assert all_qa_documents == []
def test_format_qa_document_builds_question_answer_documents(
    self, processor: QAIndexProcessor, fake_flask_app
) -> None:
    """LLM output in Qn:/An: form is parsed into one Document per question, answer in metadata."""
    all_qa_documents: list[Document] = []
    source_document = Document(page_content="source text", metadata={"origin": "doc-1"})
    with (
        patch(
            "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document",
            return_value="Q1: What is this?\nA1: A test.\nQ2: Why?\nA2: Coverage.",
        ),
        patch(
            "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
        ),
    ):
        processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English")
        assert len(all_qa_documents) == 2
        assert all_qa_documents[0].page_content == "What is this?"
        assert all_qa_documents[0].metadata["answer"] == "A test."
        assert all_qa_documents[1].metadata["answer"] == "Coverage."
def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None:
    """LLM failures are swallowed: nothing is appended and the error goes to logger.exception."""
    all_qa_documents: list[Document] = []
    source_document = Document(page_content="source text", metadata={"origin": "doc-1"})
    with (
        patch(
            "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document",
            side_effect=RuntimeError("llm failure"),
        ),
        patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger,
    ):
        processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English")
        assert all_qa_documents == []
        mock_logger.exception.assert_called_once_with("Failed to format qa document")
def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None:
    """Qn:/An: labelled text parses into ordered question/answer dicts."""
    raw = "Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n"
    expected = [
        {"question": "First?", "answer": "One."},
        {"question": "Second?", "answer": "Two."},
    ]
    assert processor._format_split_text(raw) == expected

View File

@ -0,0 +1,291 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
import httpx
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
class _ForwardingBaseIndexProcessor(BaseIndexProcessor):
    """Concrete subclass that forwards every method straight to BaseIndexProcessor.

    Lets the tests invoke the base class's default implementations directly —
    both the NotImplementedError stubs and the shared helper methods.
    """

    def extract(self, extract_setting, **kwargs):
        return super().extract(extract_setting, **kwargs)

    def transform(self, documents, current_user=None, **kwargs):
        return super().transform(documents, current_user=current_user, **kwargs)

    def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting, doc_language=None):
        return super().generate_summary_preview(
            tenant_id=tenant_id,
            preview_texts=preview_texts,
            summary_index_setting=summary_index_setting,
            doc_language=doc_language,
        )

    def load(self, dataset, documents, multimodal_documents=None, with_keywords=True, **kwargs):
        return super().load(
            dataset=dataset,
            documents=documents,
            multimodal_documents=multimodal_documents,
            with_keywords=with_keywords,
            **kwargs,
        )

    def clean(self, dataset, node_ids, with_keywords=True, **kwargs):
        return super().clean(dataset=dataset, node_ids=node_ids, with_keywords=with_keywords, **kwargs)

    def index(self, dataset, document, chunks):
        return super().index(dataset=dataset, document=document, chunks=chunks)

    def format_preview(self, chunks):
        return super().format_preview(chunks)

    def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model):
        return super().retrieve(
            retrieval_method=retrieval_method,
            query=query,
            dataset=dataset,
            top_k=top_k,
            score_threshold=score_threshold,
            reranking_model=reranking_model,
        )
class TestBaseIndexProcessor:
    """Tests for BaseIndexProcessor's default behavior and shared helpers.

    Covers the NotImplementedError stubs, the splitter factory, markdown image
    extraction, attachment-file resolution, and the download helpers.
    """

    @pytest.fixture
    def processor(self) -> _ForwardingBaseIndexProcessor:
        return _ForwardingBaseIndexProcessor()

    def test_abstract_methods_raise_not_implemented(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Every unimplemented entry point must raise NotImplementedError."""
        with pytest.raises(NotImplementedError):
            processor.extract(Mock())
        with pytest.raises(NotImplementedError):
            processor.transform([])
        with pytest.raises(NotImplementedError):
            processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {})
        with pytest.raises(NotImplementedError):
            processor.load(Mock(), [])
        with pytest.raises(NotImplementedError):
            processor.clean(Mock(), None)
        with pytest.raises(NotImplementedError):
            processor.index(Mock(), Mock(), {})
        with pytest.raises(NotImplementedError):
            processor.format_preview([])
        with pytest.raises(NotImplementedError):
            processor.retrieve("semantic_search", "q", Mock(), 3, 0.5, {})

    def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Custom max_tokens outside [50, configured max] must be rejected."""
        with patch(
            "core.rag.index_processor.index_processor_base.dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH", 1000
        ):
            with pytest.raises(ValueError, match="between 50 and 1000"):
                processor._get_splitter("custom", 49, 0, "", None)
            with pytest.raises(ValueError, match="between 50 and 1000"):
                processor._get_splitter("custom", 1001, 0, "", None)

    def test_get_splitter_custom_mode_uses_fixed_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Non-automatic modes use the fixed splitter; escaped separators are unescaped."""
        fixed_splitter = Mock()
        with patch(
            "core.rag.index_processor.index_processor_base.FixedRecursiveCharacterTextSplitter.from_encoder",
            return_value=fixed_splitter,
        ) as mock_fixed:
            splitter = processor._get_splitter("hierarchical", 120, 10, "\\n\\n", None)
            assert splitter is fixed_splitter
            # "\\n\\n" in the request becomes a literal double newline separator.
            assert mock_fixed.call_args.kwargs["fixed_separator"] == "\n\n"
            assert mock_fixed.call_args.kwargs["chunk_size"] == 120

    def test_get_splitter_automatic_mode_uses_enhance_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Automatic mode delegates to the enhanced recursive splitter."""
        auto_splitter = Mock()
        with patch(
            "core.rag.index_processor.index_processor_base.EnhanceRecursiveCharacterTextSplitter.from_encoder",
            return_value=auto_splitter,
        ) as mock_enhance:
            splitter = processor._get_splitter("automatic", 0, 0, "", None)
            assert splitter is auto_splitter
            assert "chunk_size" in mock_enhance.call_args.kwargs

    def test_extract_markdown_images(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Both remote URLs and local /files/ paths are extracted from markdown image syntax."""
        markdown = "text ![a](https://a/img.png) and ![b](/files/123/file-preview)"
        images = processor._extract_markdown_images(markdown)
        assert images == ["https://a/img.png", "/files/123/file-preview"]

    def test_get_content_files_without_images_returns_empty(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """No markdown images means no attachment documents."""
        document = Document(page_content="no image markdown", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
        assert processor._get_content_files(document) == []

    def test_get_content_files_handles_all_sources_and_duplicates(
        self, processor: _ForwardingBaseIndexProcessor
    ) -> None:
        """Preview paths, tool files, and remote URLs all resolve; duplicates are kept per reference."""
        document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
        images = [
            "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview",
            "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview",
            "/files/bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb/file-preview",
            "/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png",
            "https://example.com/remote.png?x=1",
        ]
        upload_a = SimpleNamespace(id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", name="a.png")
        upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png")
        upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png")
        upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png")
        db_query = Mock()
        db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
        db_session = Mock()
        db_session.query.return_value = db_query
        with (
            patch.object(processor, "_extract_markdown_images", return_value=images),
            patch.object(processor, "_download_tool_file", return_value="tool-upload-id") as mock_tool_download,
            patch.object(processor, "_download_image", return_value="remote-upload-id") as mock_image_download,
            patch("core.rag.index_processor.index_processor_base.db.session", db_session),
        ):
            files = processor._get_content_files(document, current_user=Mock())
            assert len(files) == 5
            assert all(isinstance(file, AttachmentDocument) for file in files)
            assert files[0].metadata["doc_type"] == DocType.IMAGE
            assert files[0].metadata["document_id"] == "doc-1"
            assert files[0].metadata["dataset_id"] == "ds-1"
            assert files[0].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
            # The duplicate preview path yields a second entry for the same upload.
            assert files[1].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
            mock_tool_download.assert_called_once()
            mock_image_download.assert_called_once()

    def test_get_content_files_skips_tool_and_remote_download_without_user(
        self, processor: _ForwardingBaseIndexProcessor
    ) -> None:
        """Tool-file and remote downloads require a user context; otherwise nothing is fetched."""
        document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
        images = ["/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", "https://example.com/remote.png"]
        with patch.object(processor, "_extract_markdown_images", return_value=images):
            files = processor._get_content_files(document, current_user=None)
            assert files == []

    def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Referenced uploads absent from the database are silently skipped."""
        document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
        images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"]
        db_query = Mock()
        db_query.where.return_value.all.return_value = []
        db_session = Mock()
        db_session.query.return_value = db_query
        with (
            patch.object(processor, "_extract_markdown_images", return_value=images),
            patch("core.rag.index_processor.index_processor_base.db.session", db_session),
        ):
            files = processor._get_content_files(document)
            assert files == []

    def test_download_image_success_with_filename_from_content_disposition(
        self, processor: _ForwardingBaseIndexProcessor
    ) -> None:
        """A successful download is uploaded via FileService and the upload id returned."""
        response = Mock()
        response.headers = {
            "Content-Length": "4",
            "content-disposition": "attachment; filename=test-image.png",
            "content-type": "image/png",
        }
        response.raise_for_status.return_value = None
        response.iter_bytes.return_value = [b"data"]
        upload_result = SimpleNamespace(id="upload-id")
        mock_db = Mock()
        mock_db.engine = Mock()
        with (
            patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response),
            patch("core.rag.index_processor.index_processor_base.db", mock_db),
            patch("services.file_service.FileService") as mock_file_service,
        ):
            mock_file_service.return_value.upload_file.return_value = upload_result
            upload_id = processor._download_image("https://example.com/test.png", current_user=Mock())
            assert upload_id == "upload-id"
            mock_file_service.return_value.upload_file.assert_called_once()

    def test_download_image_validates_size_and_empty_content(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Oversized (per Content-Length) or empty responses are rejected (None returned)."""
        too_large = Mock()
        too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"}
        too_large.raise_for_status.return_value = None
        with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large):
            assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None
        empty = Mock()
        empty.headers = {"Content-Length": "0", "content-type": "image/png"}
        empty.raise_for_status.return_value = None
        empty.iter_bytes.return_value = []
        with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty):
            assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None

    def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """Streams exceeding the size cap are aborted even without a Content-Length header."""
        response = Mock()
        response.headers = {"content-type": "image/png"}
        response.raise_for_status.return_value = None
        response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)]
        with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response):
            assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None

    def test_download_image_handles_timeout_request_and_unexpected_errors(
        self, processor: _ForwardingBaseIndexProcessor
    ) -> None:
        """Timeouts, transport errors, and unexpected exceptions all return None."""
        request = httpx.Request("GET", "https://example.com/image.png")
        with patch(
            "core.rag.index_processor.index_processor_base.ssrf_proxy.get",
            side_effect=httpx.TimeoutException("timeout"),
        ):
            assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
        with patch(
            "core.rag.index_processor.index_processor_base.ssrf_proxy.get",
            side_effect=httpx.RequestError("bad request", request=request),
        ):
            assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
        with patch(
            "core.rag.index_processor.index_processor_base.ssrf_proxy.get",
            side_effect=RuntimeError("unexpected"),
        ):
            assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None

    def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """An unknown tool file id returns None instead of raising."""
        db_query = Mock()
        db_query.where.return_value.first.return_value = None
        db_session = Mock()
        db_session.query.return_value = db_query
        with patch("core.rag.index_processor.index_processor_base.db.session", db_session):
            assert processor._download_tool_file("tool-id", current_user=Mock()) is None

    def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
        """A found tool file is loaded from storage and re-uploaded through FileService."""
        tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png")
        db_query = Mock()
        db_query.where.return_value.first.return_value = tool_file
        db_session = Mock()
        db_session.query.return_value = db_query
        mock_db = Mock()
        mock_db.session = db_session
        mock_db.engine = Mock()
        upload_result = SimpleNamespace(id="upload-id")
        with (
            patch("core.rag.index_processor.index_processor_base.db", mock_db),
            patch("core.rag.index_processor.index_processor_base.storage.load_once", return_value=b"blob") as mock_load,
            patch("services.file_service.FileService") as mock_file_service,
        ):
            mock_file_service.return_value.upload_file.return_value = upload_result
            result = processor._download_tool_file("tool-id", current_user=Mock())
            assert result == "upload-id"
            mock_load.assert_called_once_with("k1")
            mock_file_service.return_value.upload_file.assert_called_once()

View File

@ -0,0 +1,42 @@
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
class TestIndexProcessorFactory:
    """Dispatch behavior of IndexProcessorFactory.init_index_processor."""

    def test_requires_index_type(self) -> None:
        """A missing index type is rejected up front."""
        with pytest.raises(ValueError, match="Index type must be specified"):
            IndexProcessorFactory(index_type=None).init_index_processor()

    def test_builds_paragraph_processor(self) -> None:
        """PARAGRAPH_INDEX maps to ParagraphIndexProcessor."""
        built = IndexProcessorFactory(index_type=IndexStructureType.PARAGRAPH_INDEX).init_index_processor()
        assert isinstance(built, ParagraphIndexProcessor)

    def test_builds_qa_processor(self) -> None:
        """QA_INDEX maps to QAIndexProcessor."""
        built = IndexProcessorFactory(index_type=IndexStructureType.QA_INDEX).init_index_processor()
        assert isinstance(built, QAIndexProcessor)

    def test_builds_parent_child_processor(self) -> None:
        """PARENT_CHILD_INDEX maps to ParentChildIndexProcessor."""
        built = IndexProcessorFactory(index_type=IndexStructureType.PARENT_CHILD_INDEX).init_index_processor()
        assert isinstance(built, ParentChildIndexProcessor)

    def test_rejects_unsupported_index_type(self) -> None:
        """Unknown index types raise a descriptive ValueError."""
        with pytest.raises(ValueError, match="is not supported"):
            IndexProcessorFactory(index_type="unsupported").init_index_processor()

View File

@ -12,13 +12,18 @@ All tests use mocking to avoid external dependencies and ensure fast, reliable e
Tests follow the Arrange-Act-Assert pattern for clarity.
"""
from operator import itemgetter
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.model_manager import ModelInstance
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
from core.rag.rerank.rerank_factory import RerankRunnerFactory
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.rerank.rerank_type import RerankMode
@ -26,7 +31,7 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner
from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
def create_mock_model_instance():
def create_mock_model_instance() -> ModelInstance:
"""Create a properly configured mock ModelInstance for reranking tests."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
@ -59,14 +64,7 @@ class TestRerankModelRunner:
@pytest.fixture
def mock_model_instance(self):
"""Create a mock ModelInstance for reranking."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model_name = "test-model"
return mock_instance
return create_mock_model_instance()
@pytest.fixture
def rerank_runner(self, mock_model_instance):
@ -382,6 +380,206 @@ class TestRerankModelRunner:
assert call_kwargs["user"] == "user123"
class _ForwardingBaseRerankRunner(BaseRerankRunner):
    """Concrete subclass forwarding run() to the abstract base implementation."""

    def run(
        self,
        query: str,
        documents: list[Document],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
        query_type: QueryType = QueryType.TEXT_QUERY,
    ) -> list[Document]:
        return super().run(
            query=query,
            documents=documents,
            score_threshold=score_threshold,
            top_n=top_n,
            user=user,
            query_type=query_type,
        )
class TestBaseRerankRunner:
    """BaseRerankRunner.run must remain abstract."""

    def test_run_raises_not_implemented(self):
        """Invoking the forwarded base implementation raises NotImplementedError."""
        with pytest.raises(NotImplementedError):
            _ForwardingBaseRerankRunner().run(query="python", documents=[])
class TestRerankModelRunnerMultimodal:
@pytest.fixture
def mock_model_instance(self):
return create_mock_model_instance()
@pytest.fixture
def rerank_runner(self, mock_model_instance):
return RerankModelRunner(rerank_model_instance=mock_model_instance)
def test_run_returns_original_documents_for_non_text_query_without_vision_support(
self, rerank_runner, mock_model_instance
):
documents = [
Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"),
]
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY)
assert result == documents
mock_model_instance.invoke_rerank.assert_not_called()
def test_run_uses_multimodal_path_when_vision_support_is_enabled(self, rerank_runner):
documents = [
Document(page_content="doc", metadata={"doc_id": "doc1", "source": "wiki"}, provider="dify"),
]
rerank_result = RerankResult(
model="rerank-model",
docs=[RerankDocument(index=0, text="doc", score=0.88)],
)
with (
patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm,
patch.object(
rerank_runner,
"fetch_multimodal_rerank",
return_value=(rerank_result, documents),
) as mock_multimodal,
):
mock_mm.return_value.check_model_support_vision.return_value = True
result = rerank_runner.run(query="python", documents=documents, query_type=QueryType.TEXT_QUERY)
mock_multimodal.assert_called_once()
assert len(result) == 1
assert result[0].metadata["score"] == 0.88
def test_fetch_multimodal_rerank_builds_docs_and_calls_text_rerank(self, rerank_runner):
image_doc = Document(
page_content="image-content",
metadata={"doc_id": "img-1", "doc_type": DocType.IMAGE},
provider="dify",
)
text_doc = Document(
page_content="text-content",
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
provider="dify",
)
external_doc = Document(
page_content="external-content",
metadata={},
provider="external",
)
query = Mock()
query.where.return_value.first.return_value = SimpleNamespace(key="image-key")
rerank_result = RerankResult(model="rerank-model", docs=[])
with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once,
patch.object(
rerank_runner,
"fetch_text_rerank",
return_value=(rerank_result, [image_doc, text_doc, external_doc]),
) as mock_text_rerank,
):
result, unique_documents = rerank_runner.fetch_multimodal_rerank(
query="python",
documents=[image_doc, text_doc, external_doc, external_doc],
query_type=QueryType.TEXT_QUERY,
)
assert result == rerank_result
assert len(unique_documents) == 3
mock_load_once.assert_called_once_with("image-key")
text_rerank_call_args = mock_text_rerank.call_args.args
assert len(text_rerank_call_args[1]) == 3
def test_fetch_multimodal_rerank_skips_missing_image_upload(self, rerank_runner):
image_doc = Document(
page_content="image-content",
metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE},
provider="dify",
)
query = Mock()
query.where.return_value.first.return_value = None
rerank_result = RerankResult(model="rerank-model", docs=[])
with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
patch.object(
rerank_runner,
"fetch_text_rerank",
return_value=(rerank_result, [image_doc]),
) as mock_text_rerank,
):
result, unique_documents = rerank_runner.fetch_multimodal_rerank(
query="python",
documents=[image_doc],
query_type=QueryType.TEXT_QUERY,
)
assert result == rerank_result
assert unique_documents == [image_doc]
docs_arg = mock_text_rerank.call_args.args[1]
assert len(docs_arg) == 1
    def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance):
        # An IMAGE_QUERY treats `query` as an upload-file id: the upload row is
        # looked up in the DB, its bytes are loaded from storage, and reranking
        # is delegated to invoke_multimodal_rerank on the model instance.
        text_doc = Document(
            page_content="text-content",
            metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
            provider="dify",
        )
        # db.session.query(...).where(...).first() yields the query image's upload row.
        query_chain = Mock()
        query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key")
        rerank_result = RerankResult(
            model="rerank-model",
            docs=[RerankDocument(index=0, text="text-content", score=0.77)],
        )
        mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
        with (
            patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain),
            patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
        ):
            result, unique_documents = rerank_runner.fetch_multimodal_rerank(
                query="query-upload-id",
                documents=[text_doc],
                score_threshold=0.2,
                top_n=2,
                user="user-1",
                query_type=QueryType.IMAGE_QUERY,
            )
        assert result == rerank_result
        assert unique_documents == [text_doc]
        # The multimodal model must receive an image-typed query payload, the
        # text document's content, and the requesting user.
        invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs
        assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE
        assert invoke_kwargs["docs"][0]["content"] == "text-content"
        assert invoke_kwargs["user"] == "user-1"
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
query_chain = Mock()
query_chain.where.return_value.first.return_value = None
with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain):
with pytest.raises(ValueError, match="Upload file not found for query"):
rerank_runner.fetch_multimodal_rerank(
query="missing-upload-id",
documents=[],
query_type=QueryType.IMAGE_QUERY,
)
def test_fetch_multimodal_rerank_rejects_unsupported_query_type(self, rerank_runner):
with pytest.raises(ValueError, match="is not supported"):
rerank_runner.fetch_multimodal_rerank(
query="python",
documents=[],
query_type="unsupported_query_type",
)
class TestWeightRerankRunner:
"""Unit tests for WeightRerankRunner.
@ -512,34 +710,39 @@ class TestWeightRerankRunner:
- TF-IDF scores are calculated correctly
- Cosine similarity is computed for keyword vectors
"""
# Arrange: Create runner
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
# Mock keyword extraction with specific keywords
keyword_map = {
"python programming": ["python", "programming"],
"Python is a programming language": ["python", "programming", "language"],
"JavaScript for web development": ["javascript", "web"],
"Java object-oriented programming": ["java", "programming"],
}
mock_handler_instance = MagicMock()
mock_handler_instance.extract_keywords.side_effect = [
["python", "programming"], # query
["python", "programming", "language"], # doc1
["javascript", "web"], # doc2
["java", "programming"], # doc3
]
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
mock_jieba_handler.return_value = mock_handler_instance
# Mock embedding
mock_embedding_instance = MagicMock()
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
mock_cache_instance = MagicMock()
mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4]
mock_cache_embedding.return_value = mock_cache_instance
# Act: Run reranking
query_scores = runner._calculate_keyword_score("python programming", sample_documents_with_vectors)
vector_scores = runner._calculate_cosine(
"tenant123", "python programming", sample_documents_with_vectors, weights_config.vector_setting
)
expected_scores = {
doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score)
for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores)
}
result = runner.run(query="python programming", documents=sample_documents_with_vectors)
# Assert: Keywords are extracted and scores are calculated
assert len(result) == 3
# Document 1 should have highest keyword score (matches both query terms)
# Document 3 should have medium score (matches one term)
# Document 2 should have lowest score (matches no terms)
expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)]
assert [doc.metadata["doc_id"] for doc in result] == expected_order
for doc in result:
doc_id = doc.metadata["doc_id"]
assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6)
def test_vector_score_calculation(
self,
@ -556,30 +759,42 @@ class TestWeightRerankRunner:
- Cosine similarity is calculated with document vectors
- Vector scores are properly normalized
"""
# Arrange: Create runner
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
# Mock keyword extraction
keyword_map = {
"test query": ["test"],
"Python is a programming language": ["python"],
"JavaScript for web development": ["javascript"],
"Java object-oriented programming": ["java"],
}
mock_handler_instance = MagicMock()
mock_handler_instance.extract_keywords.return_value = ["test"]
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
mock_jieba_handler.return_value = mock_handler_instance
# Mock embedding model
mock_embedding_instance = MagicMock()
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
# Mock cache embedding with specific query vector
mock_cache_instance = MagicMock()
query_vector = [0.2, 0.3, 0.4, 0.5]
mock_cache_instance.embed_query.return_value = query_vector
mock_cache_embedding.return_value = mock_cache_instance
# Act: Run reranking
query_scores = runner._calculate_keyword_score("test query", sample_documents_with_vectors)
vector_scores = runner._calculate_cosine(
"tenant123", "test query", sample_documents_with_vectors, weights_config.vector_setting
)
expected_scores = {
doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score)
for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores)
}
result = runner.run(query="test query", documents=sample_documents_with_vectors)
# Assert: Vector scores are calculated
assert len(result) == 3
# Verify cosine similarity was computed (doc2 vector is closest to query vector)
expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)]
assert [doc.metadata["doc_id"] for doc in result] == expected_order
for doc in result:
doc_id = doc.metadata["doc_id"]
assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6)
def test_score_threshold_filtering_weighted(
self,
@ -742,28 +957,40 @@ class TestWeightRerankRunner:
- Keyword weight (0.4) is applied to keyword scores
- Combined score is the sum of weighted components
"""
# Arrange: Create runner with known weights
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
# Mock keyword extraction
keyword_map = {
"test": ["test"],
"Python is a programming language": ["python", "language"],
"JavaScript for web development": ["javascript", "web"],
"Java object-oriented programming": ["java", "programming"],
}
mock_handler_instance = MagicMock()
mock_handler_instance.extract_keywords.return_value = ["test"]
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
mock_jieba_handler.return_value = mock_handler_instance
# Mock embedding
mock_embedding_instance = MagicMock()
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
mock_cache_instance = MagicMock()
mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4]
mock_cache_embedding.return_value = mock_cache_instance
# Act: Run reranking
query_scores = runner._calculate_keyword_score("test", sample_documents_with_vectors)
vector_scores = runner._calculate_cosine(
"tenant123", "test", sample_documents_with_vectors, weights_config.vector_setting
)
expected_scores = {
doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score)
for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores)
}
result = runner.run(query="test", documents=sample_documents_with_vectors)
# Assert: Scores are combined with weights
# Score = 0.6 * vector_score + 0.4 * keyword_score
assert len(result) == 3
assert all("score" in doc.metadata for doc in result)
expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)]
assert [doc.metadata["doc_id"] for doc in result] == expected_order
for doc in result:
doc_id = doc.metadata["doc_id"]
assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6)
def test_existing_vector_score_in_metadata(
self,
@ -778,7 +1005,6 @@ class TestWeightRerankRunner:
- If document already has a score in metadata, it's used
- Cosine similarity calculation is skipped for such documents
"""
# Arrange: Documents with pre-existing scores
documents = [
Document(
page_content="Content with existing score",
@ -790,24 +1016,29 @@ class TestWeightRerankRunner:
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
# Mock keyword extraction
keyword_map = {
"test": ["test"],
"Content with existing score": ["test"],
}
mock_handler_instance = MagicMock()
mock_handler_instance.extract_keywords.return_value = ["test"]
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
mock_jieba_handler.return_value = mock_handler_instance
# Mock embedding
mock_embedding_instance = MagicMock()
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
mock_cache_instance = MagicMock()
mock_cache_instance.embed_query.return_value = [0.1, 0.2]
mock_cache_embedding.return_value = mock_cache_instance
# Act: Run reranking
query_scores = runner._calculate_keyword_score("test", documents)
vector_scores = runner._calculate_cosine("tenant123", "test", documents, weights_config.vector_setting)
expected_score = 0.6 * vector_scores[0] + 0.4 * query_scores[0]
result = runner.run(query="test", documents=documents)
# Assert: Existing score is used in calculation
assert len(result) == 1
# The final score should incorporate the existing score (0.95) with vector weight (0.6)
assert result[0].metadata["doc_id"] == "doc1"
assert result[0].metadata["score"] == pytest.approx(expected_score, rel=1e-6)
class TestRerankRunnerFactory:

View File

@ -1,873 +0,0 @@
"""
Unit tests for DatasetRetrieval.process_metadata_filter_func.
This module provides comprehensive test coverage for the process_metadata_filter_func
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
filter expressions based on metadata filtering conditions.
Conditions Tested:
==================
1. **String Conditions**: contains, not contains, start with, end with
2. **Equality Conditions**: is / =, is not / ≠
3. **Null Conditions**: empty, not empty
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
5. **List Conditions**: in
6. **Edge Cases**: None values, different data types (str, int, float)
Test Architecture:
==================
- Direct instantiation of DatasetRetrieval
- Mocking of DatasetDocument model attributes
- Verification of SQLAlchemy filter expressions
- Follows Arrange-Act-Assert (AAA) pattern
Running Tests:
==============
# Run all tests in this module
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
# Run a specific test
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
TestProcessMetadataFilterFunc::test_contains_condition -v
"""
from unittest.mock import MagicMock
import pytest
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
class TestProcessMetadataFilterFunc:
    """
    Comprehensive test suite for process_metadata_filter_func method.

    This test class validates all metadata filtering conditions supported by
    the DatasetRetrieval class, including string operations, numeric comparisons,
    null checks, and list operations.

    Method Signature:
    ==================
    def process_metadata_filter_func(
        self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
    ) -> list:

    The method builds SQLAlchemy filter expressions by:
    1. Validating value is not None (except for empty/not empty conditions)
    2. Using DatasetDocument.doc_metadata JSON field operations
    3. Adding appropriate SQLAlchemy expressions to the filters list
    4. Returning the updated filters list

    Mocking Strategy:
    ==================
    - Mock DatasetDocument.doc_metadata to avoid database dependencies
    - Verify filter expressions are created correctly
    - Test with various data types (str, int, float, list)
    """

    @pytest.fixture
    def retrieval(self):
        """
        Create a DatasetRetrieval instance for testing.

        Returns:
            DatasetRetrieval: Instance to test process_metadata_filter_func
        """
        # Direct instantiation: process_metadata_filter_func builds filter
        # expressions only, so no constructor dependencies need mocking here.
        return DatasetRetrieval()
@pytest.fixture
def mock_doc_metadata(self):
"""
Mock the DatasetDocument.doc_metadata JSON field.
The method uses DatasetDocument.doc_metadata[metadata_name] to access
JSON fields. We mock this to avoid database dependencies.
Returns:
Mock: Mocked doc_metadata attribute
"""
mock_metadata_field = MagicMock()
# Create mock for string access
mock_string_access = MagicMock()
mock_string_access.like = MagicMock()
mock_string_access.notlike = MagicMock()
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
mock_string_access.in_ = MagicMock(return_value=MagicMock())
# Create mock for float access (for numeric comparisons)
mock_float_access = MagicMock()
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
# Create mock for null checks
mock_null_access = MagicMock()
mock_null_access.is_ = MagicMock(return_value=MagicMock())
mock_null_access.isnot = MagicMock(return_value=MagicMock())
# Setup __getitem__ to return appropriate mock based on usage
def getitem_side_effect(name):
if name in ["author", "title", "category"]:
return mock_string_access
elif name in ["year", "price", "rating"]:
return mock_float_access
else:
return mock_string_access
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
mock_metadata_field.as_string.return_value = mock_string_access
mock_metadata_field.as_float.return_value = mock_float_access
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
return mock_metadata_field
# ==================== String Condition Tests ====================
def test_contains_condition_string_value(self, retrieval):
"""
Test 'contains' condition with string value.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value% syntax
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = "John"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_contains_condition(self, retrieval):
"""
Test 'not contains' condition.
Verifies:
- Filters list is populated with NOT LIKE expression
- Pattern matching uses %value% syntax with negation
"""
filters = []
sequence = 0
condition = "not contains"
metadata_name = "title"
value = "banned"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_start_with_condition(self, retrieval):
"""
Test 'start with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses value% syntax
"""
filters = []
sequence = 0
condition = "start with"
metadata_name = "category"
value = "tech"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_end_with_condition(self, retrieval):
"""
Test 'end with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value syntax
"""
filters = []
sequence = 0
condition = "end with"
metadata_name = "filename"
value = ".pdf"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Equality Condition Tests ====================
def test_is_condition_with_string_value(self, retrieval):
"""
Test 'is' (=) condition with string value.
Verifies:
- Filters list is populated with equality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = "Jane Doe"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_equals_condition_with_string_value(self, retrieval):
"""
Test '=' condition with string value.
Verifies:
- Same behavior as 'is' condition
- String comparison is used
"""
filters = []
sequence = 0
condition = "="
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_int_value(self, retrieval):
"""
Test 'is' condition with integer value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_float_value(self, retrieval):
"""
Test 'is' condition with float value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "price"
value = 19.99
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_string_value(self, retrieval):
"""
Test 'is not' () condition with string value.
Verifies:
- Filters list is populated with inequality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "author"
value = "Unknown"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_equals_condition(self, retrieval):
"""
Test '' condition with string value.
Verifies:
- Same behavior as 'is not' condition
- Inequality expression is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "category"
value = "archived"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_numeric_value(self, retrieval):
"""
Test 'is not' condition with numeric value.
Verifies:
- Numeric inequality comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Null Condition Tests ====================
def test_empty_condition(self, retrieval):
"""
Test 'empty' condition (null check).
Verifies:
- Filters list is populated with IS NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "empty"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_empty_condition(self, retrieval):
"""
Test 'not empty' condition (not null check).
Verifies:
- Filters list is populated with IS NOT NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "not empty"
metadata_name = "description"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Numeric Comparison Tests ====================
def test_before_condition(self, retrieval):
"""
Test 'before' (<) condition.
Verifies:
- Filters list is populated with less than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "before"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_condition(self, retrieval):
"""
Test '<' condition.
Verifies:
- Same behavior as 'before' condition
- Less than expression is used
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "price"
value = 100.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_after_condition(self, retrieval):
"""
Test 'after' (>) condition.
Verifies:
- Filters list is populated with greater than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "after"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_condition(self, retrieval):
"""
Test '>' condition.
Verifies:
- Same behavior as 'after' condition
- Greater than expression is used
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_unicode(self, retrieval):
"""
Test '' condition.
Verifies:
- Filters list is populated with less than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "price"
value = 50.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_ascii(self, retrieval):
"""
Test '<=' condition.
Verifies:
- Same behavior as '' condition
- Less than or equal expression is used
"""
filters = []
sequence = 0
condition = "<="
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_unicode(self, retrieval):
"""
Test '' condition.
Verifies:
- Filters list is populated with greater than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "rating"
value = 3.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_ascii(self, retrieval):
"""
Test '>=' condition.
Verifies:
- Same behavior as '' condition
- Greater than or equal expression is used
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== List/In Condition Tests ====================
def test_in_condition_with_comma_separated_string(self, retrieval):
"""
Test 'in' condition with comma-separated string value.
Verifies:
- String is split into list
- Whitespace is trimmed from each value
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "tech, science, AI "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_list_value(self, retrieval):
"""
Test 'in' condition with list value.
Verifies:
- List is processed correctly
- None values are filtered out
- IN expression is created with valid values
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "tags"
value = ["python", "javascript", None, "golang"]
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_tuple_value(self, retrieval):
"""
Test 'in' condition with tuple value.
Verifies:
- Tuple is processed like a list
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ("tech", "science", "ai")
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_empty_string(self, retrieval):
"""
Test 'in' condition with empty string value.
Verifies:
- Empty string results in literal(False) filter
- No valid values to match
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# Verify it's a literal(False) expression
# This is a bit tricky to test without access to the actual expression
def test_in_condition_with_only_whitespace(self, retrieval):
"""
Test 'in' condition with whitespace-only string value.
Verifies:
- Whitespace-only string results in literal(False) filter
- All values are stripped and filtered out
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = " , , "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_single_string(self, retrieval):
"""
Test 'in' condition with single non-comma string.
Verifies:
- Single string is treated as single-item list
- IN expression is created with one value
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Edge Case Tests ====================
def test_none_value_with_non_empty_condition(self, retrieval):
"""
Test None value with conditions that require value.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values (except empty/not empty)
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0 # No filter added
def test_none_value_with_equals_condition(self, retrieval):
"""
Test None value with 'is' (=) condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_none_value_with_numeric_condition(self, retrieval):
"""
Test None value with numeric comparison condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "year"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_existing_filters_preserved(self, retrieval):
"""
Test that existing filters are preserved.
Verifies:
- Existing filters in the list are not removed
- New filters are appended to the list
"""
existing_filter = MagicMock()
filters = [existing_filter]
sequence = 0
condition = "contains"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 2
assert filters[0] == existing_filter
def test_multiple_filters_accumulated(self, retrieval):
"""
Test multiple calls to accumulate filters.
Verifies:
- Each call adds a new filter to the list
- All filters are preserved across calls
"""
filters = []
# First filter
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
assert len(filters) == 1
# Second filter
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
assert len(filters) == 2
# Third filter
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
assert len(filters) == 3
def test_unknown_condition(self, retrieval):
"""
Test unknown/unsupported condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for unknown conditions
"""
filters = []
sequence = 0
condition = "unknown_condition"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_empty_string_value_with_contains(self, retrieval):
"""
Test empty string value with 'contains' condition.
Verifies:
- Filter is added even with empty string
- LIKE expression is created
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_special_characters_in_value(self, retrieval):
"""
Test special characters in value string.
Verifies:
- Special characters are handled in value
- LIKE expression is created correctly
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "title"
value = "C++ & Python's features"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_zero_value_with_numeric_condition(self, retrieval):
"""
Test zero value with numeric comparison condition.
Verifies:
- Zero is treated as valid value
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "price"
value = 0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_negative_value_with_numeric_condition(self, retrieval):
"""
Test negative value with numeric comparison condition.
Verifies:
- Negative numbers are handled correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "temperature"
value = -10.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_float_value_with_integer_comparison(self, retrieval):
"""
Test float value with numeric comparison condition.
Verifies:
- Float values work correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1

View File

@ -1,113 +0,0 @@
import threading
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from flask import Flask, current_app
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from models.dataset import Dataset
class TestRetrievalService:
    """Regression tests for DatasetRetrieval multi-dataset retrieval threading."""

    @pytest.fixture
    def mock_dataset(self) -> Dataset:
        """Provide a Mock standing in for a high-quality Dify dataset."""
        ds = Mock(spec=Dataset)
        ds.id = str(uuid4())
        ds.tenant_id = str(uuid4())
        ds.name = "test_dataset"
        ds.indexing_technique = "high_quality"
        ds.provider = "dify"
        return ds

    def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
        """Reranking must execute inside the Flask app context.

        `_multiple_retrieve_thread` swallows exceptions and records them in
        the `thread_exceptions` list instead of raising, so the assertions
        below inspect that list rather than using an outer try/except. If
        reranking ran after `with flask_app.app_context():` had exited, the
        context-dependent post-processor would raise and an exception would
        be recorded.
        """
        dataset_retrieval = DatasetRetrieval()
        flask_app = Flask(__name__)
        tenant_id = str(uuid4())

        # A second dataset makes dataset_count > 1, which activates the
        # reranking branch under test.
        secondary_dataset = Mock(spec=Dataset)
        secondary_dataset.id = str(uuid4())
        secondary_dataset.provider = "dify"
        secondary_dataset.indexing_technique = "high_quality"

        # The stub retriever feeds exactly one document into the internal
        # per-dataset result list (all_documents argument it receives).
        document = Document(
            page_content="Context aware doc",
            metadata={
                "doc_id": "doc1",
                "score": 0.95,
                "document_id": str(uuid4()),
                "dataset_id": mock_dataset.id,
            },
            provider="dify",
        )

        def fake_retriever(
            flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
        ):
            all_documents.append(document)

        called = {"init": 0, "invoke": 0}

        class ContextRequiredPostProcessor:
            """Stand-in DataPostProcessor that touches `current_app`, so it
            raises RuntimeError whenever no Flask app context is active."""

            def __init__(self, *args, **kwargs):
                called["init"] += 1
                _ = current_app.name  # fails fast outside an app context

            def invoke(self, *args, **kwargs):
                called["invoke"] += 1
                _ = current_app.name
                return kwargs.get("documents") or args[1]

        # Output list populated by _multiple_retrieve_thread.
        all_documents: list[Document] = []
        # Failures inside the thread are appended here rather than raised.
        thread_exceptions: list[Exception] = []

        def target():
            with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
                with patch(
                    "core.rag.retrieval.dataset_retrieval.DataPostProcessor",
                    ContextRequiredPostProcessor,
                ):
                    dataset_retrieval._multiple_retrieve_thread(
                        flask_app=flask_app,
                        available_datasets=[mock_dataset, secondary_dataset],
                        metadata_condition=None,
                        metadata_filter_document_ids=None,
                        all_documents=all_documents,
                        tenant_id=tenant_id,
                        reranking_enable=True,
                        reranking_mode="reranking_model",
                        reranking_model={
                            "reranking_provider_name": "cohere",
                            "reranking_model_name": "rerank-v2",
                        },
                        weights=None,
                        top_k=3,
                        score_threshold=0.0,
                        query="test query",
                        attachment_id=None,
                        dataset_count=2,  # force the reranking branch
                        thread_exceptions=thread_exceptions,
                    )

        worker = threading.Thread(target=target)
        worker.start()
        worker.join()

        # The reranking branch must actually have executed.
        assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."
        # No context-related failure may have been recorded by the thread.
        assert not thread_exceptions, thread_exceptions

View File

@ -0,0 +1,100 @@
from unittest.mock import Mock
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
class TestFunctionCallMultiDatasetRouter:
    """Behavioural tests for FunctionCallMultiDatasetRouter.invoke."""

    def test_invoke_returns_none_when_no_tools(self) -> None:
        """With no dataset tools the router selects nothing and reports empty usage."""
        router = FunctionCallMultiDatasetRouter()
        chosen, usage = router.invoke(
            query="python",
            dataset_tools=[],
            model_config=Mock(),
            model_instance=Mock(),
        )
        assert chosen is None
        assert usage == LLMUsage.empty_usage()

    def test_invoke_returns_single_tool_directly(self) -> None:
        """A single tool short-circuits: the lone tool wins without an LLM call."""
        router = FunctionCallMultiDatasetRouter()
        only_tool = Mock()
        only_tool.name = "dataset-1"
        chosen, usage = router.invoke(
            query="python",
            dataset_tools=[only_tool],
            model_config=Mock(),
            model_instance=Mock(),
        )
        assert chosen == "dataset-1"
        assert usage == LLMUsage.empty_usage()

    def test_invoke_returns_tool_from_model_response(self) -> None:
        """With several tools the router picks the one named in the LLM tool call."""
        router = FunctionCallMultiDatasetRouter()
        first = Mock()
        first.name = "dataset-1"
        second = Mock()
        second.name = "dataset-2"
        expected_usage = LLMUsage.empty_usage()
        response = Mock()
        response.usage = expected_usage
        response.message.tool_calls = [Mock(function=Mock())]
        response.message.tool_calls[0].function.name = "dataset-2"
        model_instance = Mock()
        model_instance.invoke_llm.return_value = response
        chosen, reported_usage = router.invoke(
            query="python",
            dataset_tools=[first, second],
            model_config=Mock(),
            model_instance=model_instance,
        )
        assert chosen == "dataset-2"
        assert reported_usage == expected_usage
        model_instance.invoke_llm.assert_called_once()

    def test_invoke_returns_none_when_no_tool_calls(self) -> None:
        """A response without tool calls yields no selection but keeps its usage."""
        router = FunctionCallMultiDatasetRouter()
        response = Mock()
        response.usage = LLMUsage.empty_usage()
        response.message.tool_calls = []
        model_instance = Mock()
        model_instance.invoke_llm.return_value = response
        first = Mock()
        first.name = "dataset-1"
        second = Mock()
        second.name = "dataset-2"
        chosen, usage = router.invoke(
            query="python",
            dataset_tools=[first, second],
            model_config=Mock(),
            model_instance=model_instance,
        )
        assert chosen is None
        assert usage == response.usage

    def test_invoke_returns_empty_usage_when_model_raises(self) -> None:
        """LLM failures are swallowed: the router reports no choice and empty usage."""
        router = FunctionCallMultiDatasetRouter()
        model_instance = Mock()
        model_instance.invoke_llm.side_effect = RuntimeError("boom")
        first = Mock()
        first.name = "dataset-1"
        second = Mock()
        second.name = "dataset-2"
        chosen, usage = router.invoke(
            query="python",
            dataset_tools=[first, second],
            model_config=Mock(),
            model_instance=model_instance,
        )
        assert chosen is None
        assert usage == LLMUsage.empty_usage()

View File

@ -0,0 +1,252 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
class TestReactMultiDatasetRouter:
    """Tests for ReactMultiDatasetRouter: invoke, the ReAct loop and prompt builders."""

    def test_invoke_returns_none_when_no_tools(self) -> None:
        """An empty tool list yields no selection and empty usage."""
        router = ReactMultiDatasetRouter()
        chosen, usage = router.invoke(
            query="python",
            dataset_tools=[],
            model_config=Mock(),
            model_instance=Mock(),
            user_id="u1",
            tenant_id="t1",
        )
        assert chosen is None
        assert usage == LLMUsage.empty_usage()

    def test_invoke_returns_single_tool_directly(self) -> None:
        """A lone tool is returned without entering the ReAct loop."""
        router = ReactMultiDatasetRouter()
        only_tool = Mock()
        only_tool.name = "dataset-1"
        chosen, usage = router.invoke(
            query="python",
            dataset_tools=[only_tool],
            model_config=Mock(),
            model_instance=Mock(),
            user_id="u1",
            tenant_id="t1",
        )
        assert chosen == "dataset-1"
        assert usage == LLMUsage.empty_usage()

    def test_invoke_returns_tool_from_react_invoke(self) -> None:
        """With several tools, invoke defers to _react_invoke and relays its result."""
        router = ReactMultiDatasetRouter()
        expected_usage = LLMUsage.empty_usage()
        first = Mock()
        first.name = "dataset-1"
        second = Mock()
        second.name = "dataset-2"
        with patch.object(router, "_react_invoke", return_value=("dataset-2", expected_usage)) as mock_react:
            chosen, reported_usage = router.invoke(
                query="python",
                dataset_tools=[first, second],
                model_config=Mock(),
                model_instance=Mock(),
                user_id="u1",
                tenant_id="t1",
            )
        mock_react.assert_called_once()
        assert chosen == "dataset-2"
        assert reported_usage == expected_usage

    def test_invoke_handles_react_invoke_errors(self) -> None:
        """Errors inside _react_invoke collapse to a (None, empty usage) result."""
        router = ReactMultiDatasetRouter()
        first = Mock()
        first.name = "dataset-1"
        second = Mock()
        second.name = "dataset-2"
        with patch.object(router, "_react_invoke", side_effect=RuntimeError("boom")):
            chosen, usage = router.invoke(
                query="python",
                dataset_tools=[first, second],
                model_config=Mock(),
                model_instance=Mock(),
                user_id="u1",
                tenant_id="t1",
            )
        assert chosen is None
        assert usage == LLMUsage.empty_usage()

    def test_react_invoke_returns_action_tool(self) -> None:
        """Chat mode: a parsed ReactAction resolves to the named dataset tool."""
        router = ReactMultiDatasetRouter()
        model_config = Mock()
        model_config.mode = "chat"
        model_config.parameters = {"temperature": 0.1}
        expected_usage = LLMUsage.empty_usage()
        tools = [Mock(), Mock()]
        tools[0].name = "dataset-1"
        tools[0].description = "desc"
        tools[1].name = "dataset-2"
        tools[1].description = "desc"
        with (
            patch.object(router, "create_chat_prompt", return_value=[Mock()]) as mock_chat_prompt,
            patch(
                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
            ) as mock_prompt_transform,
            patch.object(
                router, "_invoke_llm", return_value=('{"action":"dataset-2","action_input":{}}', expected_usage)
            ),
            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
        ):
            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
            mock_parser_cls.return_value.parse.return_value = ReactAction("dataset-2", {}, "log")
            chosen, reported_usage = router._react_invoke(
                query="python",
                model_config=model_config,
                model_instance=Mock(),
                tools=tools,
                user_id="u1",
                tenant_id="t1",
            )
        mock_chat_prompt.assert_called_once()
        assert chosen == "dataset-2"
        assert reported_usage == expected_usage

    def test_react_invoke_returns_none_for_finish(self) -> None:
        """Completion mode: a ReactFinish means no dataset is selected."""
        router = ReactMultiDatasetRouter()
        model_config = Mock()
        model_config.mode = "completion"
        model_config.parameters = {"temperature": 0.1}
        expected_usage = LLMUsage.empty_usage()
        tool = Mock()
        tool.name = "dataset-1"
        tool.description = "desc"
        with (
            patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt,
            patch(
                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
            ) as mock_prompt_transform,
            patch.object(
                router, "_invoke_llm", return_value=('{"action":"Final Answer","action_input":"done"}', expected_usage)
            ),
            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
        ):
            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
            mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log")
            chosen, reported_usage = router._react_invoke(
                query="python",
                model_config=model_config,
                model_instance=Mock(),
                tools=[tool],
                user_id="u1",
                tenant_id="t1",
            )
        mock_completion_prompt.assert_called_once()
        assert chosen is None
        assert reported_usage == expected_usage

    def test_invoke_llm_and_handle_result(self) -> None:
        """_invoke_llm streams chunks, concatenates content and deducts quota once."""
        router = ReactMultiDatasetRouter()
        expected_usage = LLMUsage.empty_usage()
        delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=expected_usage)
        chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta)
        model_instance = Mock()
        model_instance.invoke_llm.return_value = iter([chunk])
        with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct:
            text, reported_usage = router._invoke_llm(
                completion_param={"temperature": 0.1},
                model_instance=model_instance,
                prompt_messages=[Mock()],
                stop=["Observation:"],
                user_id="u1",
                tenant_id="t1",
            )
            assert text == "part"
            assert reported_usage == expected_usage
            mock_deduct.assert_called_once()

    def test_handle_invoke_result_with_empty_usage(self) -> None:
        """A chunk whose delta carries no usage falls back to LLMUsage.empty_usage()."""
        router = ReactMultiDatasetRouter()
        delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=None)
        chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta)
        text, usage = router._handle_invoke_result(iter([chunk]))
        assert text == "part"
        assert usage == LLMUsage.empty_usage()

    def test_create_chat_prompt(self) -> None:
        """System + user messages are produced and the system text lists every tool."""
        router = ReactMultiDatasetRouter()
        first = Mock()
        first.name = "dataset-1"
        first.description = "d1"
        second = Mock()
        second.name = "dataset-2"
        second.description = "d2"
        messages = router.create_chat_prompt(query="python", tools=[first, second])
        assert len(messages) == 2
        assert messages[0].role == PromptMessageRole.SYSTEM
        assert messages[1].role == PromptMessageRole.USER
        assert "dataset-1" in messages[0].text
        assert "dataset-2" in messages[0].text

    def test_create_completion_prompt(self) -> None:
        """The completion prompt embeds each tool as 'name: description'."""
        router = ReactMultiDatasetRouter()
        first = Mock()
        first.name = "dataset-1"
        first.description = "d1"
        second = Mock()
        second.name = "dataset-2"
        second.description = "d2"
        prompt = router.create_completion_prompt(tools=[first, second])
        assert "dataset-1: d1" in prompt.text
        assert "dataset-2: d2" in prompt.text

    def test_react_invoke_uses_completion_branch_for_non_chat_mode(self) -> None:
        """Any non-chat mode falls through to the completion-prompt branch."""
        router = ReactMultiDatasetRouter()
        model_config = Mock()
        model_config.mode = "unknown-mode"
        model_config.parameters = {}
        tool = Mock()
        tool.name = "dataset-1"
        tool.description = "desc"
        with (
            patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt,
            patch(
                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
            ) as mock_prompt_transform,
            patch.object(
                router,
                "_invoke_llm",
                return_value=('{"action":"Final Answer","action_input":"done"}', LLMUsage.empty_usage()),
            ),
            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
        ):
            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
            mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log")
            chosen, usage = router._react_invoke(
                query="python",
                model_config=model_config,
                model_instance=Mock(),
                tools=[tool],
                user_id="u1",
                tenant_id="t1",
            )
        mock_completion_prompt.assert_called_once()
        assert chosen is None
        assert usage == LLMUsage.empty_usage()

View File

@ -0,0 +1,69 @@
import pytest
from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
class TestStructuredChatOutputParser:
    """Parsing tests for StructuredChatOutputParser.parse."""

    def test_parse_action_without_action_input(self) -> None:
        """An action block lacking 'action_input' defaults the tool input to {}."""
        parser = StructuredChatOutputParser()
        text = 'Action:\n```json\n{"action":"some_action"}\n```'
        parsed = parser.parse(text)
        assert isinstance(parsed, ReactAction)
        assert parsed.tool == "some_action"
        assert parsed.tool_input == {}

    def test_parse_json_without_action_key(self) -> None:
        """JSON missing the 'action' key is rejected with ValueError."""
        parser = StructuredChatOutputParser()
        text = 'Action:\n```json\n{"not_action":"search"}\n```'
        with pytest.raises(ValueError, match="Could not parse LLM output"):
            parser.parse(text)

    def test_parse_returns_action_for_tool_call(self) -> None:
        """A well-formed tool call parses into a ReactAction carrying the raw log."""
        parser = StructuredChatOutputParser()
        text = (
            'Thought: call tool\nAction:\n```json\n{"action":"search_dataset","action_input":{"query":"python"}}\n```'
        )
        parsed = parser.parse(text)
        assert isinstance(parsed, ReactAction)
        assert parsed.tool == "search_dataset"
        assert parsed.tool_input == {"query": "python"}
        assert parsed.log == text

    def test_parse_returns_finish_for_final_answer(self) -> None:
        """A 'Final Answer' action parses into a ReactFinish with the answer as output."""
        parser = StructuredChatOutputParser()
        text = 'Thought: done\nAction:\n```json\n{"action":"Final Answer","action_input":"final text"}\n```'
        parsed = parser.parse(text)
        assert isinstance(parsed, ReactFinish)
        assert parsed.return_values == {"output": "final text"}
        assert parsed.log == text

    def test_parse_returns_finish_for_json_array_payload(self) -> None:
        """A JSON array payload is not an action dict, so the full text becomes the output."""
        parser = StructuredChatOutputParser()
        text = 'Action:\n```json\n[{"action":"search","action_input":"hello"}]\n```'
        parsed = parser.parse(text)
        assert isinstance(parsed, ReactFinish)
        assert parsed.return_values == {"output": text}
        assert parsed.log == text

    def test_parse_returns_finish_for_plain_text(self) -> None:
        """Free text with no action block is treated as a final answer."""
        parser = StructuredChatOutputParser()
        text = "No structured action block"
        parsed = parser.parse(text)
        assert isinstance(parsed, ReactFinish)
        assert parsed.return_values == {"output": text}

    def test_parse_raises_value_error_for_invalid_json(self) -> None:
        """Malformed JSON inside the action block raises ValueError."""
        parser = StructuredChatOutputParser()
        text = 'Action:\n```json\n{"action":"search","action_input": }\n```'
        with pytest.raises(ValueError, match="Could not parse LLM output"):
            parser.parse(text)

View File

@ -125,7 +125,11 @@ Run with coverage:
- Tests are organized by functionality in classes for better organization
"""
import asyncio
import string
import sys
import types
from inspect import currentframe
from unittest.mock import Mock, patch
import pytest
@ -604,6 +608,51 @@ class TestRecursiveCharacterTextSplitter:
assert "def hello_world" in combined or "hello_world" in combined
class TestTextSplitterBasePaths:
    """Target uncovered base TextSplitter code paths."""

    def test_from_huggingface_tokenizer_success_path(self):
        """Cover the from_huggingface_tokenizer success branch with a stubbed transformers module."""

        class _FakePreTrainedTokenizerBase:
            pass

        class _FakeTokenizer(_FakePreTrainedTokenizerBase):
            def encode(self, text: str):
                return [ord(ch) for ch in text]

        stub_transformers = types.SimpleNamespace(PreTrainedTokenizerBase=_FakePreTrainedTokenizerBase)
        with patch.dict(sys.modules, {"transformers": stub_transformers}):
            splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer=_FakeTokenizer(),
                chunk_size=5,
                chunk_overlap=1,
            )
            chunks = splitter.split_text("abcdef")
        assert chunks

    def test_from_huggingface_tokenizer_import_error(self):
        """Cover the import-error branch when transformers is unavailable."""
        with patch.dict(sys.modules, {"transformers": None}):
            with pytest.raises(ValueError, match="Could not import transformers"):
                RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer=object(), chunk_size=5)

    def test_atransform_documents_raises_not_implemented(self):
        """Cover the NotImplementedError branch of atransform_documents."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
        with pytest.raises(NotImplementedError):
            asyncio.run(splitter.atransform_documents([Document(page_content="x", metadata={})]))

    def test_merge_splits_logs_warning_for_oversized_total(self):
        """Cover the logger.warning path in _merge_splits."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1)
        with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning:
            merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1])
        assert merged
        mock_warning.assert_called_once()
# ============================================================================
# Test TokenTextSplitter
# ============================================================================
@ -662,6 +711,44 @@ class TestTokenTextSplitter:
except ImportError:
pytest.skip("tiktoken not installed")
def test_initialization_and_split_with_mocked_tiktoken_encoding(self):
    """Cover the TokenTextSplitter __init__ else-path (encoding_name) and split_text logic."""

    class _StubEncoding:
        def encode(self, text: str, allowed_special=None, disallowed_special=None):
            return [ord(ch) for ch in text]

        def decode(self, token_ids: list[int]) -> str:
            return "".join(chr(token_id) for token_id in token_ids)

    stub_tiktoken = types.SimpleNamespace(get_encoding=lambda name: _StubEncoding())
    with patch.dict(sys.modules, {"tiktoken": stub_tiktoken}):
        splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=4, chunk_overlap=1)
        pieces = splitter.split_text("abcdefgh")
    assert pieces
    assert all(isinstance(piece, str) for piece in pieces)
def test_initialization_with_model_name_uses_encoding_for_model(self):
    """Cover the TokenTextSplitter model_name init branch (encoding_for_model wins)."""

    class _StubEncoding:
        def encode(self, text: str, allowed_special=None, disallowed_special=None):
            return [ord(ch) for ch in text]

        def decode(self, token_ids: list[int]) -> str:
            return "".join(chr(token_id) for token_id in token_ids)

    model_encoding = _StubEncoding()
    stub_tiktoken = types.SimpleNamespace(
        encoding_for_model=lambda model_name: model_encoding,
        get_encoding=lambda name: _StubEncoding(),
    )
    with patch.dict(sys.modules, {"tiktoken": stub_tiktoken}):
        splitter = TokenTextSplitter(model_name="gpt-4", chunk_size=5, chunk_overlap=1)
    # The tokenizer must be exactly the object returned by encoding_for_model.
    assert splitter._tokenizer is model_encoding
# ============================================================================
# Test EnhanceRecursiveCharacterTextSplitter
@ -731,6 +818,50 @@ class TestEnhanceRecursiveCharacterTextSplitter:
assert len(result) > 0
assert all(isinstance(chunk, str) for chunk in result)
def test_from_encoder_internal_token_encoder_paths(self):
    """
    Exercise the internal _token_encoder branches by capturing the closures
    from the calling frame.

    Covers:
    - empty texts path
    - embedding model path
    - GPT2Tokenizer fallback path
    - _character_encoder empty-path branch
    """

    class _SpySplitter(EnhanceRecursiveCharacterTextSplitter):
        captured_token_encoder = None
        captured_character_encoder = None

        def __init__(self, **kwargs):
            # from_encoder defines _token_encoder/_character_encoder as
            # locals; grab them off the caller's frame before delegating.
            frame = currentframe()
            if frame and frame.f_back:
                _SpySplitter.captured_token_encoder = frame.f_back.f_locals.get("_token_encoder")
                _SpySplitter.captured_character_encoder = frame.f_back.f_locals.get("_character_encoder")
            super().__init__(**kwargs)

    mock_model = Mock()
    mock_model.get_text_embedding_num_tokens.return_value = [3, 5]
    _SpySplitter.from_encoder(embedding_model_instance=mock_model, chunk_size=10, chunk_overlap=1)
    token_encoder = _SpySplitter.captured_token_encoder
    character_encoder = _SpySplitter.captured_character_encoder
    assert token_encoder is not None
    assert character_encoder is not None
    assert token_encoder([]) == []
    assert token_encoder(["abc", "defgh"]) == [3, 5]
    assert character_encoder([]) == []

    with patch(
        "core.rag.splitter.fixed_text_splitter.GPT2Tokenizer.get_num_tokens",
        side_effect=lambda text: len(text) + 1,
    ):
        _SpySplitter.from_encoder(embedding_model_instance=None, chunk_size=10, chunk_overlap=1)
        token_encoder_without_model = _SpySplitter.captured_token_encoder
        assert token_encoder_without_model is not None
        # Invoked while the patch is active so the stubbed len+1 counting applies.
        assert token_encoder_without_model(["ab", "cdef"]) == [3, 5]
# ============================================================================
# Test FixedRecursiveCharacterTextSplitter
@ -908,6 +1039,56 @@ class TestFixedRecursiveCharacterTextSplitter:
chunks = splitter.split_text(data)
assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
def test_recursive_split_keep_separator_and_recursive_fallback(self):
    """Cover the keep-separator split branch and the recursive _split_text fallback."""
    sample = "short." + ("x" * 60)
    splitter = FixedRecursiveCharacterTextSplitter(
        fixed_separator="",
        separators=[".", " ", ""],
        chunk_size=10,
        chunk_overlap=2,
        keep_separator=True,
    )
    pieces = splitter.recursive_split_text(sample)
    assert pieces
    assert any("short." in piece for piece in pieces)
    assert any(len(piece) <= 12 for piece in pieces)
def test_recursive_split_newline_separator_filtering(self):
    """Cover the newline-specific empty-chunk filtering branch."""
    sample = "line1\n\nline2\n\nline3"
    splitter = FixedRecursiveCharacterTextSplitter(
        fixed_separator="",
        separators=["\n", ""],
        chunk_size=50,
        chunk_overlap=5,
    )
    pieces = splitter.recursive_split_text(sample)
    assert pieces
    assert all(piece != "" for piece in pieces)
    joined = "".join(pieces)
    assert "line1" in joined
    assert "line2" in joined
    assert "line3" in joined
def test_recursive_split_without_new_separator_appends_long_chunk(self):
    """Cover the branch where no further separator exists and the long split is appended as-is."""
    sample = "aa\n" + ("b" * 40)
    splitter = FixedRecursiveCharacterTextSplitter(
        fixed_separator="",
        separators=["\n"],
        chunk_size=10,
        chunk_overlap=2,
    )
    pieces = splitter.recursive_split_text(sample)
    assert "aa" in pieces
    assert any(len(piece) >= 40 for piece in pieces)
# ============================================================================
# Test Metadata Preservation