mirror of
https://github.com/langgenius/dify.git
synced 2026-03-28 15:51:00 +08:00
test: unit test for core.rag module (#32630)
This commit is contained in:
parent
a5832df586
commit
a0ed350871
813
api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py
Normal file
813
api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py
Normal file
@ -0,0 +1,813 @@
|
||||
"""
|
||||
Unit tests for DatasetDocumentStore.
|
||||
|
||||
Tests cover all public methods and error paths of the DatasetDocumentStore class
|
||||
which provides document storage and retrieval functionality for datasets in the RAG system.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore, DocumentSegment
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class TestDatasetDocumentStoreInit:
    """Tests for DatasetDocumentStore initialization."""

    @staticmethod
    def _dataset() -> MagicMock:
        # Minimal Dataset stand-in exposing only the ``id`` attribute.
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"
        return ds

    def test_init_with_all_parameters(self):
        """Constructing with dataset, user_id and document_id stores all three."""
        ds = self._dataset()

        store = DatasetDocumentStore(
            dataset=ds,
            user_id="test-user-id",
            document_id="test-doc-id",
        )

        assert store._dataset == ds
        assert store._user_id == "test-user-id"
        assert store._document_id == "test-doc-id"
        assert store.dataset_id == "test-dataset-id"
        assert store.user_id == "test-user-id"

    def test_init_without_document_id(self):
        """Omitting document_id leaves it as None."""
        store = DatasetDocumentStore(
            dataset=self._dataset(),
            user_id="test-user-id",
        )

        assert store._document_id is None
        assert store.dataset_id == "test-dataset-id"
||||
class TestDatasetDocumentStoreSerialization:
    """Tests for to_dict and from_dict methods."""

    def test_to_dict(self):
        """to_dict serializes only the dataset id."""
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"

        serialized = DatasetDocumentStore(dataset=ds, user_id="test-user-id").to_dict()

        assert serialized == {"dataset_id": "test-dataset-id"}

    def test_from_dict(self):
        """from_dict rebuilds a store from a config mapping."""
        ds = MagicMock(spec=["id"])
        ds.id = "ds-123"

        store = DatasetDocumentStore.from_dict(
            {
                "dataset": ds,
                "user_id": "test-user",
                "document_id": "test-doc",
            }
        )

        assert store._user_id == "test-user"
        assert store._document_id == "test-doc"
||||
class TestDatasetDocumentStoreDocs:
    """Tests for the docs property."""

    def test_docs_returns_document_dict(self):
        """docs maps each segment's index_node_id to a Document instance."""
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"

        segment = MagicMock(spec=DocumentSegment)
        segment.index_node_id = "node-1"
        segment.index_node_hash = "hash-1"
        segment.document_id = "doc-1"
        segment.dataset_id = "test-dataset-id"
        segment.content = "Test content"

        with patch("core.rag.docstore.dataset_docstore.db") as fake_db:
            fake_db.session = MagicMock()
            fake_db.session.scalars.return_value.all.return_value = [segment]

            docs = DatasetDocumentStore(dataset=ds, user_id="test-user-id").docs

            assert "node-1" in docs
            assert isinstance(docs["node-1"], Document)

    def test_docs_empty_dataset(self):
        """docs yields an empty dict when the dataset has no segments."""
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"

        with patch("core.rag.docstore.dataset_docstore.db") as fake_db:
            fake_db.session = MagicMock()
            fake_db.session.scalars.return_value.all.return_value = []

            docs = DatasetDocumentStore(dataset=ds, user_id="test-user-id").docs

            assert docs == {}
||||
class TestDatasetDocumentStoreAddDocuments:
    """Tests for add_documents method."""

    @staticmethod
    def _dataset(indexing_technique: str) -> MagicMock:
        # Dataset mock with the id/tenant fields the store reads.
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"
        ds.tenant_id = "tenant-1"
        ds.indexing_technique = indexing_technique
        return ds

    @staticmethod
    def _doc(content: str, metadata: dict, children=None) -> MagicMock:
        # Document mock with no attachments and optional children.
        doc = MagicMock(spec=Document)
        doc.page_content = content
        doc.metadata = metadata
        doc.attachments = None
        doc.children = children
        return doc

    @staticmethod
    def _make_store(dataset) -> DatasetDocumentStore:
        return DatasetDocumentStore(
            dataset=dataset,
            user_id="test-user-id",
            document_id="test-doc-id",
        )

    def test_add_documents_new_document_with_embedding(self):
        """A new document on a high_quality dataset is persisted and committed."""
        dataset = self._dataset("high_quality")
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "model"

        doc = self._doc("Test content", {"doc_id": "doc-1", "doc_hash": "hash-1"})

        model_instance = MagicMock()
        model_instance.get_text_embedding_num_tokens.return_value = [10]

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch("core.rag.docstore.dataset_docstore.ModelManager") as manager_cls,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            fake_db.session = MagicMock()
            fake_db.session.query.return_value.where.return_value.scalar.return_value = None

            manager = MagicMock()
            manager.get_model_instance.return_value = model_instance
            manager_cls.return_value = manager

            self._make_store(dataset).add_documents([doc])

            fake_db.session.add.assert_called()
            fake_db.session.commit.assert_called()

    def test_add_documents_update_existing_document(self):
        """An existing segment is updated in place when allow_update defaults to True."""
        dataset = self._dataset("economy")
        dataset.embedding_model_provider = None
        dataset.embedding_model = None

        doc = self._doc("Updated content", {"doc_id": "doc-1", "doc_hash": "new-hash"})

        existing = MagicMock()
        existing.id = "seg-1"

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=existing),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            fake_db.session = MagicMock()
            fake_db.session.query.return_value.where.return_value.scalar.return_value = 5

            self._make_store(dataset).add_documents([doc])

            fake_db.session.commit.assert_called()

    def test_add_documents_raises_when_not_allowed(self):
        """Re-adding an existing doc with allow_update=False raises ValueError."""
        dataset = self._dataset("economy")
        doc = self._doc("Test content", {"doc_id": "doc-1", "doc_hash": "hash-1"})

        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=MagicMock()),
        ):
            store = self._make_store(dataset)

            with pytest.raises(ValueError, match="already exists"):
                store.add_documents([doc], allow_update=False)

    def test_add_documents_with_answer_metadata(self):
        """A document carrying an answer metadata key is persisted."""
        dataset = self._dataset("economy")
        doc = self._doc(
            "Test content",
            {"doc_id": "doc-1", "doc_hash": "hash-1", "answer": "Test answer"},
        )

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            fake_db.session = MagicMock()
            fake_db.session.query.return_value.where.return_value.scalar.return_value = None

            self._make_store(dataset).add_documents([doc])

            fake_db.session.add.assert_called()

    def test_add_documents_with_invalid_document_type(self):
        """Anything that is not a Document is rejected with ValueError."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"

        with patch("core.rag.docstore.dataset_docstore.db"):
            store = self._make_store(dataset)

            with pytest.raises(ValueError, match="must be a Document"):
                store.add_documents(["not a document"])

    def test_add_documents_with_none_metadata(self):
        """A Document whose metadata is None is rejected with ValueError."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"

        doc = MagicMock(spec=Document)
        doc.page_content = "Test content"
        doc.metadata = None

        with patch("core.rag.docstore.dataset_docstore.db"):
            store = self._make_store(dataset)

            with pytest.raises(ValueError, match="metadata must be a dict"):
                store.add_documents([doc])

    def test_add_documents_with_save_child(self):
        """save_child=True persists the parent along with its child documents."""
        dataset = self._dataset("economy")

        child = MagicMock(spec=Document)
        child.page_content = "Child content"
        child.metadata = {"doc_id": "child-1", "doc_hash": "child-hash"}

        doc = self._doc(
            "Test content",
            {"doc_id": "doc-1", "doc_hash": "hash-1"},
            children=[child],
        )

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            fake_db.session = MagicMock()
            fake_db.session.query.return_value.where.return_value.scalar.return_value = None

            self._make_store(dataset).add_documents([doc], save_child=True)

            fake_db.session.add.assert_called()
||||
class TestDatasetDocumentStoreExists:
    """Tests for document_exists method."""

    @staticmethod
    def _store() -> DatasetDocumentStore:
        # Store backed by a minimal Dataset mock.
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"
        return DatasetDocumentStore(dataset=ds, user_id="test-user-id")

    def test_document_exists_returns_true(self):
        """document_exists is True when a backing segment is found."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=MagicMock()),
        ):
            assert self._store().document_exists("doc-1") is True

    def test_document_exists_returns_false(self):
        """document_exists is False when no segment matches."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            assert self._store().document_exists("doc-1") is False
||||
class TestDatasetDocumentStoreGetDocument:
    """Tests for get_document method."""

    @staticmethod
    def _store() -> DatasetDocumentStore:
        # Store backed by a minimal Dataset mock.
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"
        return DatasetDocumentStore(dataset=ds, user_id="test-user-id")

    def test_get_document_success(self):
        """A found segment is converted into a Document carrying its content."""
        segment = MagicMock(spec=DocumentSegment)
        segment.index_node_id = "node-1"
        segment.index_node_hash = "hash-1"
        segment.document_id = "doc-1"
        segment.dataset_id = "test-dataset-id"
        segment.content = "Test content"

        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            document = self._store().get_document("node-1", raise_error=False)

            assert isinstance(document, Document)
            assert document.page_content == "Test content"

    def test_get_document_returns_none_when_not_found(self):
        """Missing doc with raise_error=False yields None."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            assert self._store().get_document("nonexistent", raise_error=False) is None

    def test_get_document_raises_when_not_found(self):
        """Missing doc with raise_error=True raises ValueError."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = self._store()

            with pytest.raises(ValueError, match="not found"):
                store.get_document("nonexistent", raise_error=True)
||||
class TestDatasetDocumentStoreDeleteDocument:
    """Tests for delete_document method."""

    @staticmethod
    def _store() -> DatasetDocumentStore:
        # Store backed by a minimal Dataset mock.
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"
        return DatasetDocumentStore(dataset=ds, user_id="test-user-id")

    def test_delete_document_success(self):
        """Deleting an existing document removes its segment and commits."""
        segment = MagicMock()

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            self._store().delete_document("doc-1")

            fake_db.session.delete.assert_called_with(segment)
            fake_db.session.commit.assert_called()

    def test_delete_document_returns_none_when_not_found(self):
        """Missing doc with raise_error=False yields None."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            assert self._store().delete_document("nonexistent", raise_error=False) is None

    def test_delete_document_raises_when_not_found(self):
        """Missing doc with raise_error=True raises ValueError."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            store = self._store()

            with pytest.raises(ValueError, match="not found"):
                store.delete_document("nonexistent", raise_error=True)
||||
class TestDatasetDocumentStoreHashOperations:
    """Tests for set_document_hash and get_document_hash methods."""

    @staticmethod
    def _store() -> DatasetDocumentStore:
        # Store backed by a minimal Dataset mock.
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"
        return DatasetDocumentStore(dataset=ds, user_id="test-user-id")

    def test_set_document_hash_success(self):
        """set_document_hash overwrites the segment hash and commits."""
        segment = MagicMock()
        segment.index_node_hash = "old-hash"

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            self._store().set_document_hash("doc-1", "new-hash")

            assert segment.index_node_hash == "new-hash"
            fake_db.session.commit.assert_called()

    def test_set_document_hash_returns_none_when_not_found(self):
        """set_document_hash is a no-op returning None for a missing segment."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            assert self._store().set_document_hash("nonexistent", "new-hash") is None

    def test_get_document_hash_success(self):
        """get_document_hash returns the stored index_node_hash."""
        segment = MagicMock()
        segment.index_node_hash = "test-hash"

        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=segment),
        ):
            assert self._store().get_document_hash("doc-1") == "test-hash"

    def test_get_document_hash_returns_none_when_not_found(self):
        """get_document_hash returns None for a missing segment."""
        with (
            patch("core.rag.docstore.dataset_docstore.db"),
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=None),
        ):
            assert self._store().get_document_hash("nonexistent") is None
||||
class TestDatasetDocumentStoreSegment:
    """Tests for get_document_segment method."""

    def test_get_document_segment_returns_segment(self):
        """The segment returned by the DB scalar query is passed straight through."""
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"

        segment = MagicMock(spec=DocumentSegment)

        with patch("core.rag.docstore.dataset_docstore.db") as fake_db:
            fake_db.session = MagicMock()
            fake_db.session.scalar.return_value = segment

            store = DatasetDocumentStore(dataset=ds, user_id="test-user-id")

            assert store.get_document_segment("doc-1") == segment

    def test_get_document_segment_returns_none(self):
        """A DB miss yields None."""
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"

        with patch("core.rag.docstore.dataset_docstore.db") as fake_db:
            fake_db.session = MagicMock()
            fake_db.session.scalar.return_value = None

            store = DatasetDocumentStore(dataset=ds, user_id="test-user-id")

            assert store.get_document_segment("nonexistent") is None
||||
class TestDatasetDocumentStoreMultimodelBinding:
    """Tests for add_multimodel_documents_binding method."""

    @staticmethod
    def _store() -> DatasetDocumentStore:
        # Store backed by a Dataset mock exposing id and tenant_id.
        ds = MagicMock(spec=Dataset)
        ds.id = "test-dataset-id"
        ds.tenant_id = "tenant-1"
        return DatasetDocumentStore(
            dataset=ds,
            user_id="test-user-id",
            document_id="test-doc-id",
        )

    def test_add_multimodel_documents_binding_with_attachments(self):
        """Each attachment produces a binding row via session.add."""
        attachment = MagicMock(spec=AttachmentDocument)
        attachment.metadata = {"doc_id": "attachment-1"}

        with patch("core.rag.docstore.dataset_docstore.db") as fake_db:
            fake_db.session = MagicMock()

            self._store().add_multimodel_documents_binding("seg-1", [attachment])

            fake_db.session.add.assert_called()

    def test_add_multimodel_documents_binding_without_attachments(self):
        """Passing None for attachments writes nothing."""
        with patch("core.rag.docstore.dataset_docstore.db") as fake_db:
            fake_db.session = MagicMock()

            self._store().add_multimodel_documents_binding("seg-1", None)

            fake_db.session.add.assert_not_called()

    def test_add_multimodel_documents_binding_with_empty_list(self):
        """An empty attachment list writes nothing."""
        with patch("core.rag.docstore.dataset_docstore.db") as fake_db:
            fake_db.session = MagicMock()

            self._store().add_multimodel_documents_binding("seg-1", [])

            fake_db.session.add.assert_not_called()
||||
class TestDatasetDocumentStoreAddDocumentsUpdateChild:
    """Tests for add_documents when updating existing documents with children."""

    def test_add_documents_update_existing_with_children(self):
        """Updating an existing doc with save_child=True replaces its child rows."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"

        child = MagicMock(spec=Document)
        child.page_content = "Updated child content"
        child.metadata = {"doc_id": "child-1", "doc_hash": "new-child-hash"}

        doc = MagicMock(spec=Document)
        doc.page_content = "Updated content"
        doc.metadata = {"doc_id": "doc-1", "doc_hash": "new-hash"}
        doc.attachments = None
        doc.children = [child]

        existing = MagicMock()
        existing.id = "seg-1"

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=existing),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            fake_db.session = MagicMock()
            fake_db.session.query.return_value.where.return_value.scalar.return_value = 5

            store = DatasetDocumentStore(
                dataset=dataset,
                user_id="test-user-id",
                document_id="test-doc-id",
            )
            store.add_documents([doc], save_child=True)

            # Existing child segments must be deleted before new ones are written.
            fake_db.session.query.return_value.where.return_value.delete.assert_called()
            fake_db.session.commit.assert_called()
||||
class TestDatasetDocumentStoreAddDocumentsUpdateAnswer:
    """Tests for add_documents when updating existing documents with answer metadata."""

    def test_add_documents_update_existing_with_answer(self):
        """Updating an existing document that carries an answer metadata key commits."""
        dataset = MagicMock(spec=Dataset)
        dataset.id = "test-dataset-id"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "economy"

        doc = MagicMock(spec=Document)
        doc.page_content = "Updated content"
        doc.metadata = {
            "doc_id": "doc-1",
            "doc_hash": "new-hash",
            "answer": "Updated answer",
        }
        doc.attachments = None
        doc.children = None

        existing = MagicMock()
        existing.id = "seg-1"

        with (
            patch("core.rag.docstore.dataset_docstore.db") as fake_db,
            patch.object(DatasetDocumentStore, "get_document_segment", return_value=existing),
            patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"),
        ):
            fake_db.session = MagicMock()
            fake_db.session.query.return_value.where.return_value.scalar.return_value = 5

            store = DatasetDocumentStore(
                dataset=dataset,
                user_id="test-user-id",
                document_id="test-doc-id",
            )
            store.add_documents([doc])

            fake_db.session.commit.assert_called()
555
api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py
Normal file
555
api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py
Normal file
@ -0,0 +1,555 @@
|
||||
"""Unit tests for cached_embedding.py - CacheEmbedding class.
|
||||
|
||||
This test file covers the methods not fully tested in test_embedding_service.py:
|
||||
- embed_multimodal_documents
|
||||
- embed_multimodal_query
|
||||
- Error handling scenarios in embed_query (DEBUG mode)
|
||||
"""
|
||||
|
||||
import base64
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
|
||||
from models.dataset import Embedding
|
||||
|
||||
|
||||
class TestCacheEmbeddingMultimodalDocuments:
|
||||
"""Test suite for CacheEmbedding.embed_multimodal_documents method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance(self):
|
||||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "vision-embedding-model"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
|
||||
model_type_instance = Mock()
|
||||
model_instance.model_type_instance = model_type_instance
|
||||
|
||||
model_schema = Mock()
|
||||
model_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
|
||||
model_type_instance.get_model_schema.return_value = model_schema
|
||||
|
||||
return model_instance
|
||||
|
||||
@pytest.fixture
|
||||
def sample_multimodal_result(self):
|
||||
"""Create a sample multimodal EmbeddingResult."""
|
||||
embedding_vector = np.random.randn(1536)
|
||||
normalized_vector = (embedding_vector / np.linalg.norm(embedding_vector)).tolist()
|
||||
|
||||
usage = EmbeddingUsage(
|
||||
tokens=10,
|
||||
total_tokens=10,
|
||||
unit_price=Decimal("0.0001"),
|
||||
price_unit=Decimal(1000),
|
||||
total_price=Decimal("0.000001"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
return EmbeddingResult(
|
||||
model="vision-embedding-model",
|
||||
embeddings=[normalized_vector],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result):
|
||||
"""Test embedding a single multimodal document when cache is empty."""
|
||||
cache_embedding = CacheEmbedding(mock_model_instance, user="test-user")
|
||||
documents = [{"file_id": "file123", "content": "test content"}]
|
||||
|
||||
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
|
||||
|
||||
result = cache_embedding.embed_multimodal_documents(documents)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], list)
|
||||
assert len(result[0]) == 1536
|
||||
|
||||
mock_model_instance.invoke_multimodal_embedding.assert_called_once()
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_embed_multiple_multimodal_documents_cache_miss(self, mock_model_instance):
    """Test embedding multiple multimodal documents when cache is empty."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [
        {"file_id": "file1", "content": "content 1"},
        {"file_id": "file2", "content": "content 2"},
        {"file_id": "file3", "content": "content 3"},
    ]

    def unit_vector():
        raw = np.random.randn(1536)
        return (raw / np.linalg.norm(raw)).tolist()

    usage = EmbeddingUsage(
        tokens=30,
        total_tokens=30,
        unit_price=Decimal("0.0001"),
        price_unit=Decimal(1000),
        total_price=Decimal("0.000003"),
        currency="USD",
        latency=0.8,
    )
    provider_result = EmbeddingResult(
        model="vision-embedding-model",
        embeddings=[unit_vector() for _ in range(3)],
        usage=usage,
    )

    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        # No cached rows: every document goes through the provider.
        mock_session.query.return_value.filter_by.return_value.first.return_value = None
        mock_model_instance.invoke_multimodal_embedding.return_value = provider_result

        vectors = embedder.embed_multimodal_documents(docs)

        assert len(vectors) == 3
        assert all(len(emb) == 1536 for emb in vectors)
|
||||
|
||||
def test_embed_multimodal_documents_cache_hit(self, mock_model_instance):
    """Test embedding multimodal documents when embeddings are cached."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [{"file_id": "file123"}]

    raw = np.random.randn(1536)
    cached_vector = (raw / np.linalg.norm(raw)).tolist()

    cached_row = Mock(spec=Embedding)
    cached_row.get_embedding.return_value = cached_vector

    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        # DB lookup returns an existing cache row.
        mock_session.query.return_value.filter_by.return_value.first.return_value = cached_row

        vectors = embedder.embed_multimodal_documents(docs)

        assert len(vectors) == 1
        assert vectors[0] == cached_vector
        # Provider must not be invoked when the cache already holds the vector.
        mock_model_instance.invoke_multimodal_embedding.assert_not_called()
|
||||
|
||||
def test_embed_multimodal_documents_partial_cache_hit(self, mock_model_instance):
    """Test embedding multimodal documents with mixed cache hits and misses."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [
        {"file_id": "cached_file"},
        {"file_id": "new_file_1"},
        {"file_id": "new_file_2"},
    ]

    raw = np.random.randn(1536)
    cached_vector = (raw / np.linalg.norm(raw)).tolist()
    cached_row = Mock(spec=Embedding)
    cached_row.get_embedding.return_value = cached_vector

    def unit_vector():
        v = np.random.randn(1536)
        return (v / np.linalg.norm(v)).tolist()

    usage = EmbeddingUsage(
        tokens=20,
        total_tokens=20,
        unit_price=Decimal("0.0001"),
        price_unit=Decimal(1000),
        total_price=Decimal("0.000002"),
        currency="USD",
        latency=0.6,
    )
    provider_result = EmbeddingResult(
        model="vision-embedding-model",
        embeddings=[unit_vector() for _ in range(2)],
        usage=usage,
    )

    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        lookup_count = [0]

        def fake_filter_by(**kwargs):
            # First DB lookup hits the cache; every later lookup misses.
            lookup_count[0] += 1
            fake_query = Mock()
            fake_query.first.return_value = cached_row if lookup_count[0] == 1 else None
            return fake_query

        mock_session.query.return_value.filter_by = fake_filter_by
        mock_model_instance.invoke_multimodal_embedding.return_value = provider_result

        vectors = embedder.embed_multimodal_documents(docs)

        # Cached vector is returned in position 0; the two misses are computed.
        assert len(vectors) == 3
        assert vectors[0] == cached_vector
|
||||
|
||||
def test_embed_multimodal_documents_nan_handling(self, mock_model_instance):
    """Test handling of NaN values in multimodal embeddings."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [{"file_id": "valid"}, {"file_id": "nan"}]

    valid_vector = np.random.randn(1536).tolist()
    nan_vector = [float("nan")] * 1536

    usage = EmbeddingUsage(
        tokens=20,
        total_tokens=20,
        unit_price=Decimal("0.0001"),
        price_unit=Decimal(1000),
        total_price=Decimal("0.000002"),
        currency="USD",
        latency=0.5,
    )
    provider_result = EmbeddingResult(
        model="vision-embedding-model",
        embeddings=[valid_vector, nan_vector],
        usage=usage,
    )

    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        mock_session.query.return_value.filter_by.return_value.first.return_value = None
        mock_model_instance.invoke_multimodal_embedding.return_value = provider_result

        with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
            vectors = embedder.embed_multimodal_documents(docs)

            # The NaN embedding is dropped (None) and logged; the valid one survives.
            assert len(vectors) == 2
            assert vectors[0] is not None
            assert vectors[1] is None
            mock_logger.warning.assert_called_once()
|
||||
|
||||
def test_embed_multimodal_documents_large_batch(self, mock_model_instance):
    """Test embedding large batch of multimodal documents respecting MAX_CHUNKS."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [{"file_id": f"file{i}"} for i in range(25)]

    def batch_result(batch_size):
        # Provider response for one batch: batch_size unit vectors plus usage.
        vectors = []
        for _ in range(batch_size):
            raw = np.random.randn(1536)
            vectors.append((raw / np.linalg.norm(raw)).tolist())
        usage = EmbeddingUsage(
            tokens=batch_size * 10,
            total_tokens=batch_size * 10,
            unit_price=Decimal("0.0001"),
            price_unit=Decimal(1000),
            total_price=Decimal(str(batch_size * 0.000001)),
            currency="USD",
            latency=0.5,
        )
        return EmbeddingResult(
            model="vision-embedding-model",
            embeddings=vectors,
            usage=usage,
        )

    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        mock_session.query.return_value.filter_by.return_value.first.return_value = None
        # MAX_CHUNKS is 10 (see fixture), so 25 docs should split into 10 + 10 + 5.
        mock_model_instance.invoke_multimodal_embedding.side_effect = [
            batch_result(10),
            batch_result(10),
            batch_result(5),
        ]

        vectors = embedder.embed_multimodal_documents(docs)

        assert len(vectors) == 25
        assert mock_model_instance.invoke_multimodal_embedding.call_count == 3
|
||||
|
||||
def test_embed_multimodal_documents_api_error(self, mock_model_instance):
    """Test handling of API errors during multimodal embedding."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [{"file_id": "file123"}]

    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        mock_session.query.return_value.filter_by.return_value.first.return_value = None
        mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")

        # The provider failure must propagate and the DB session must roll back.
        with pytest.raises(Exception) as exc_info:
            embedder.embed_multimodal_documents(docs)

        assert "API Error" in str(exc_info.value)
        mock_session.rollback.assert_called()
|
||||
|
||||
def test_embed_multimodal_documents_integrity_error_during_transform(
    self, mock_model_instance, sample_multimodal_result
):
    """Test handling of IntegrityError during embedding transformation."""
    embedder = CacheEmbedding(mock_model_instance)
    docs = [{"file_id": "file123"}]

    with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
        mock_session.query.return_value.filter_by.return_value.first.return_value = None
        mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
        # Simulate a duplicate-key race when persisting the new cache row.
        mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None)

        vectors = embedder.embed_multimodal_documents(docs)

        # The embedding is still returned; the failed insert is rolled back.
        assert len(vectors) == 1
        mock_session.rollback.assert_called()
|
||||
|
||||
|
||||
class TestCacheEmbeddingMultimodalQuery:
    """Test suite for CacheEmbedding.embed_multimodal_query method."""

    @pytest.fixture
    def mock_model_instance(self):
        """Create a mock ModelInstance for testing."""
        model_instance = Mock()
        model_instance.model = "vision-embedding-model"
        model_instance.provider = "openai"
        model_instance.credentials = {"api_key": "test-key"}
        return model_instance

    @staticmethod
    def _query_usage():
        """Usage payload for a single-query provider response.

        Extracted because the identical literal was previously duplicated in
        four tests below.
        """
        return EmbeddingUsage(
            tokens=5,
            total_tokens=5,
            unit_price=Decimal("0.0001"),
            price_unit=Decimal(1000),
            total_price=Decimal("0.0000005"),
            currency="USD",
            latency=0.3,
        )

    def test_embed_multimodal_query_cache_miss(self, mock_model_instance):
        """Test embedding multimodal query when Redis cache is empty."""
        cache_embedding = CacheEmbedding(mock_model_instance, user="test-user")
        document = {"file_id": "file123"}

        raw = np.random.randn(1536)
        unit_vector = (raw / np.linalg.norm(raw)).tolist()
        provider_result = EmbeddingResult(
            model="vision-embedding-model",
            embeddings=[unit_vector],
            usage=self._query_usage(),
        )

        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.return_value = provider_result

            result = cache_embedding.embed_multimodal_query(document)

            assert isinstance(result, list)
            assert len(result) == 1536
            # Fresh result must be written back to Redis with a TTL.
            mock_redis.setex.assert_called_once()

    def test_embed_multimodal_query_cache_hit(self, mock_model_instance):
        """Test embedding multimodal query when Redis cache has the value."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        document = {"file_id": "file123"}

        # Cache stores the raw float64 buffer, base64-encoded.
        embedding_vector = np.random.randn(1536)
        encoded_vector = base64.b64encode(embedding_vector.tobytes()).decode("utf-8")

        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = encoded_vector.encode()

            result = cache_embedding.embed_multimodal_query(document)

            assert isinstance(result, list)
            assert len(result) == 1536
            # Hit refreshes the TTL and never calls the provider.
            mock_redis.expire.assert_called_once()
            mock_model_instance.invoke_multimodal_embedding.assert_not_called()

    def test_embed_multimodal_query_nan_handling(self, mock_model_instance):
        """Test handling of NaN values in multimodal query embeddings."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        document = {"file_id": "file123"}

        provider_result = EmbeddingResult(
            model="vision-embedding-model",
            embeddings=[[float("nan")] * 1536],
            usage=self._query_usage(),
        )

        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.return_value = provider_result

            with pytest.raises(ValueError) as exc_info:
                cache_embedding.embed_multimodal_query(document)

            assert "Normalized embedding is nan" in str(exc_info.value)

    def test_embed_multimodal_query_api_error(self, mock_model_instance):
        """Test handling of API errors during multimodal query embedding."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        document = {"file_id": "file123"}

        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")

            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                mock_config.DEBUG = False

                with pytest.raises(Exception) as exc_info:
                    cache_embedding.embed_multimodal_query(document)

                assert "API Error" in str(exc_info.value)

    def test_embed_multimodal_query_redis_set_error(self, mock_model_instance):
        """Test handling of Redis set errors during multimodal query embedding."""
        cache_embedding = CacheEmbedding(mock_model_instance)
        document = {"file_id": "file123"}

        raw = np.random.randn(1536)
        provider_result = EmbeddingResult(
            model="vision-embedding-model",
            embeddings=[(raw / np.linalg.norm(raw)).tolist()],
            usage=self._query_usage(),
        )

        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_multimodal_embedding.return_value = provider_result
            mock_redis.setex.side_effect = RuntimeError("Redis Error")

            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                # In DEBUG mode the cache-write failure is re-raised.
                mock_config.DEBUG = True

                with pytest.raises(RuntimeError):
                    cache_embedding.embed_multimodal_query(document)
|
||||
|
||||
|
||||
class TestCacheEmbeddingQueryErrors:
    """Test suite for error handling in CacheEmbedding.embed_query method."""

    @pytest.fixture
    def mock_model_instance(self):
        """Create a mock ModelInstance for testing."""
        instance = Mock()
        instance.model = "text-embedding-ada-002"
        instance.provider = "openai"
        instance.credentials = {"api_key": "test-key"}
        return instance

    def test_embed_query_api_error_debug_mode(self, mock_model_instance):
        """Test handling of API errors in debug mode."""
        embedder = CacheEmbedding(mock_model_instance)
        query = "test query"

        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_text_embedding.side_effect = RuntimeError("API Error")

            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                mock_config.DEBUG = True

                with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
                    # Provider failure propagates and is logged with traceback.
                    with pytest.raises(RuntimeError) as exc_info:
                        embedder.embed_query(query)

                    assert "API Error" in str(exc_info.value)
                    mock_logger.exception.assert_called()

    def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance):
        """Test handling of Redis set errors in debug mode."""
        embedder = CacheEmbedding(mock_model_instance)
        query = "test query"

        raw = np.random.randn(1536)
        usage = EmbeddingUsage(
            tokens=5,
            total_tokens=5,
            unit_price=Decimal("0.0001"),
            price_unit=Decimal(1000),
            total_price=Decimal("0.0000005"),
            currency="USD",
            latency=0.3,
        )
        provider_result = EmbeddingResult(
            model="text-embedding-ada-002",
            embeddings=[(raw / np.linalg.norm(raw)).tolist()],
            usage=usage,
        )

        with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis:
            mock_redis.get.return_value = None
            mock_model_instance.invoke_text_embedding.return_value = provider_result
            mock_redis.setex.side_effect = RuntimeError("Redis Error")

            with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config:
                mock_config.DEBUG = True

                with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
                    # Cache-write failure is re-raised in DEBUG mode and logged.
                    with pytest.raises(RuntimeError):
                        embedder.embed_query(query)

                    mock_logger.exception.assert_called()
|
||||
|
||||
|
||||
class TestCacheEmbeddingInitialization:
    """Test suite for CacheEmbedding initialization."""

    @staticmethod
    def _model_stub():
        """Minimal model-instance stand-in; only identity attributes are set.

        Extracted because both tests previously duplicated this construction.
        """
        stub = Mock()
        stub.model = "test-model"
        stub.provider = "test-provider"
        return stub

    def test_initialization_with_user(self):
        """Test CacheEmbedding initialization with user parameter."""
        model_instance = self._model_stub()

        cache_embedding = CacheEmbedding(model_instance, user="test-user")

        assert cache_embedding._model_instance == model_instance
        assert cache_embedding._user == "test-user"

    def test_initialization_without_user(self):
        """Test CacheEmbedding initialization without user parameter."""
        model_instance = self._model_stub()

        cache_embedding = CacheEmbedding(model_instance)

        assert cache_embedding._model_instance == model_instance
        # Omitted user defaults to None.
        assert cache_embedding._user is None
|
||||
220
api/tests/unit_tests/core/rag/embedding/test_embedding_base.py
Normal file
220
api/tests/unit_tests/core/rag/embedding/test_embedding_base.py
Normal file
@ -0,0 +1,220 @@
|
||||
"""Unit tests for embedding_base.py - the abstract Embeddings base class."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
|
||||
|
||||
class ConcreteEmbeddings(Embeddings):
    """Minimal Embeddings subclass used to exercise the abstract base.

    Every method returns constant 10-dimensional vectors so tests only need
    to check shapes, not values.
    """

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return one constant 10-dim vector per input text."""
        return [[1.0] * 10 for _ in texts]

    def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
        """Return one constant 10-dim vector per input document."""
        return [[1.0] * 10 for _ in multimodel_documents]

    def embed_query(self, text: str) -> list[float]:
        """Return a constant 10-dim query vector."""
        return [1.0] * 10

    def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
        """Return a constant 10-dim query vector."""
        return [1.0] * 10
|
||||
|
||||
|
||||
class TestEmbeddingsBase:
    """Test suite for the abstract Embeddings base class."""

    @staticmethod
    def _body_raises_not_implemented(method) -> bool:
        # Inspect the method's source text instead of invoking it, so the
        # check works on abstract methods without an instance.
        return "raise NotImplementedError" in inspect.getsource(method)

    def test_embeddings_is_abc(self):
        """Test that Embeddings is an abstract base class."""
        assert hasattr(Embeddings, "__abstractmethods__")
        assert len(Embeddings.__abstractmethods__) > 0

    def test_embed_documents_is_abstract(self):
        """Test that embed_documents is an abstract method."""
        assert "embed_documents" in Embeddings.__abstractmethods__

    def test_embed_multimodal_documents_is_abstract(self):
        """Test that embed_multimodal_documents is an abstract method."""
        assert "embed_multimodal_documents" in Embeddings.__abstractmethods__

    def test_embed_query_is_abstract(self):
        """Test that embed_query is an abstract method."""
        assert "embed_query" in Embeddings.__abstractmethods__

    def test_embed_multimodal_query_is_abstract(self):
        """Test that embed_multimodal_query is an abstract method."""
        assert "embed_multimodal_query" in Embeddings.__abstractmethods__

    def test_embed_documents_raises_not_implemented(self):
        """Test that embed_documents raises NotImplementedError in its body."""
        assert self._body_raises_not_implemented(Embeddings.embed_documents)

    def test_embed_multimodal_documents_raises_not_implemented(self):
        """Test that embed_multimodal_documents raises NotImplementedError in its body."""
        assert self._body_raises_not_implemented(Embeddings.embed_multimodal_documents)

    def test_embed_query_raises_not_implemented(self):
        """Test that embed_query raises NotImplementedError in its body."""
        assert self._body_raises_not_implemented(Embeddings.embed_query)

    def test_embed_multimodal_query_raises_not_implemented(self):
        """Test that embed_multimodal_query raises NotImplementedError in its body."""
        assert self._body_raises_not_implemented(Embeddings.embed_multimodal_query)

    def test_aembed_documents_raises_not_implemented(self):
        """Test that aembed_documents raises NotImplementedError in its body."""
        assert self._body_raises_not_implemented(Embeddings.aembed_documents)

    def test_aembed_query_raises_not_implemented(self):
        """Test that aembed_query raises NotImplementedError in its body."""
        assert self._body_raises_not_implemented(Embeddings.aembed_query)

    def test_concrete_implementation_works(self):
        """Test that a concrete implementation of Embeddings works correctly."""
        impl = ConcreteEmbeddings()
        vectors = impl.embed_documents(["test1", "test2"])
        assert len(vectors) == 2
        assert all(len(emb) == 10 for emb in vectors)

    def test_concrete_implementation_embed_query(self):
        """Test concrete implementation of embed_query."""
        impl = ConcreteEmbeddings()
        assert len(impl.embed_query("test query")) == 10

    def test_concrete_implementation_embed_multimodal_documents(self):
        """Test concrete implementation of embed_multimodal_documents."""
        impl = ConcreteEmbeddings()
        docs: list[dict[str, Any]] = [{"file_id": "file1"}, {"file_id": "file2"}]
        assert len(impl.embed_multimodal_documents(docs)) == 2

    def test_concrete_implementation_embed_multimodal_query(self):
        """Test concrete implementation of embed_multimodal_query."""
        impl = ConcreteEmbeddings()
        assert len(impl.embed_multimodal_query({"file_id": "test"})) == 10
|
||||
|
||||
|
||||
class TestEmbeddingsNotImplemented:
    """Test that abstract methods raise NotImplementedError when called."""

    @staticmethod
    def _make_partial_impl():
        """Build an instance whose methods delegate to the abstract base bodies.

        The class deliberately does NOT inherit from Embeddings, so it can be
        instantiated even though every delegated call should raise
        NotImplementedError. Extracted because this construction was
        previously copy-pasted into all six tests below.
        """

        class PartialImpl:
            pass

        PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text)
        PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts)
        PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs)
        PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc)
        PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts)
        PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text)

        return PartialImpl()

    def test_embed_query_raises_not_implemented(self):
        """Test that embed_query raises NotImplementedError."""
        partial = self._make_partial_impl()
        with pytest.raises(NotImplementedError):
            partial.embed_query("test")

    def test_embed_documents_raises_not_implemented(self):
        """Test that embed_documents raises NotImplementedError."""
        partial = self._make_partial_impl()
        with pytest.raises(NotImplementedError):
            partial.embed_documents(["test"])

    def test_embed_multimodal_documents_raises_not_implemented(self):
        """Test that embed_multimodal_documents raises NotImplementedError."""
        partial = self._make_partial_impl()
        with pytest.raises(NotImplementedError):
            partial.embed_multimodal_documents([{"file_id": "test"}])

    def test_embed_multimodal_query_raises_not_implemented(self):
        """Test that embed_multimodal_query raises NotImplementedError."""
        partial = self._make_partial_impl()
        with pytest.raises(NotImplementedError):
            partial.embed_multimodal_query({"file_id": "test"})

    def test_aembed_documents_raises_not_implemented(self):
        """Test that aembed_documents raises NotImplementedError."""
        partial = self._make_partial_impl()

        async def run_test():
            with pytest.raises(NotImplementedError):
                await partial.aembed_documents(["test"])

        asyncio.run(run_test())

    def test_aembed_query_raises_not_implemented(self):
        """Test that aembed_query raises NotImplementedError."""
        partial = self._make_partial_impl()

        async def run_test():
            with pytest.raises(NotImplementedError):
                await partial.aembed_query("test")

        asyncio.run(run_test())
|
||||
85
api/tests/unit_tests/core/rag/extractor/blob/test_blob.py
Normal file
85
api/tests/unit_tests/core/rag/extractor/blob/test_blob.py
Normal file
@ -0,0 +1,85 @@
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.extractor.blob.blob import Blob
|
||||
|
||||
|
||||
class TestBlob:
    """Unit tests for the Blob value object (data/path-backed binary payloads)."""

    def test_requires_data_or_path(self):
        """A Blob with neither payload nor path is rejected at construction."""
        with pytest.raises(ValueError, match="Either data or path must be provided"):
            Blob()

    def test_source_property_and_repr_include_path(self, tmp_path):
        """source and repr expose the backing file path."""
        sample = tmp_path / "sample.txt"
        sample.write_text("hello", encoding="utf-8")

        blob = Blob.from_path(str(sample))

        assert blob.source == str(sample)
        assert str(sample) in repr(blob)

    def test_as_string_from_bytes_and_str(self):
        """as_string decodes bytes payloads and passes str payloads through."""
        assert Blob.from_data(b"abc").as_string() == "abc"
        assert Blob.from_data("plain-text").as_string() == "plain-text"

    def test_as_string_from_path(self, tmp_path):
        """as_string reads text content from a path-backed blob."""
        sample = tmp_path / "sample.txt"
        sample.write_text("from-file", encoding="utf-8")

        assert Blob.from_path(str(sample)).as_string() == "from-file"

    def test_as_string_raises_for_invalid_state(self):
        """as_string fails on a blob with no data and no path."""
        # model_construct bypasses validation, producing an otherwise impossible blob.
        empty = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8")

        with pytest.raises(ValueError, match="Unable to get string for blob"):
            empty.as_string()

    def test_as_bytes_from_bytes_str_and_path(self, tmp_path):
        """as_bytes works for bytes, str (encoded), and path-backed blobs."""
        sample = tmp_path / "sample.bin"
        sample.write_bytes(b"from-path")

        assert Blob.from_data(b"abc").as_bytes() == b"abc"
        assert Blob.from_data("abc", encoding="utf-8").as_bytes() == b"abc"
        assert Blob.from_path(str(sample)).as_bytes() == b"from-path"

    def test_as_bytes_raises_for_invalid_state(self):
        """as_bytes fails on a blob with no data and no path."""
        empty = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8")

        with pytest.raises(ValueError, match="Unable to get bytes for blob"):
            empty.as_bytes()

    def test_as_bytes_io_for_bytes_and_path(self, tmp_path):
        """as_bytes_io yields a readable stream for bytes and path blobs."""
        with Blob.from_data(b"bytes-io").as_bytes_io() as stream:
            assert isinstance(stream, BytesIO)
            assert stream.read() == b"bytes-io"

        sample = tmp_path / "stream.bin"
        sample.write_bytes(b"path-stream")
        with Blob.from_path(str(sample)).as_bytes_io() as stream:
            assert stream.read() == b"path-stream"

    def test_as_bytes_io_raises_for_unsupported_data_type(self):
        """as_bytes_io rejects str-backed blobs."""
        text_blob = Blob.from_data("text-value")

        with pytest.raises(NotImplementedError, match="Unable to convert blob"):
            with text_blob.as_bytes_io():
                pass

    def test_from_path_respects_guessing_and_explicit_mime(self, tmp_path):
        """from_path guesses the mimetype unless an explicit one is forced."""
        sample = tmp_path / "example.txt"
        sample.write_text("x", encoding="utf-8")

        guessed = Blob.from_path(str(sample))
        explicit = Blob.from_path(str(sample), mime_type="custom/type", guess_type=False)

        assert guessed.mimetype == "text/plain"
        assert explicit.mimetype == "custom/type"
|
||||
@ -1,61 +1,337 @@
|
||||
import os
|
||||
"""Unit tests for Firecrawl app and extractor integration points."""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
import core.rag.extractor.firecrawl.firecrawl_app as firecrawl_module
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
|
||||
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
|
||||
|
||||
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
||||
url = "https://firecrawl.dev"
|
||||
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
|
||||
base_url = "https://api.firecrawl.dev"
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
|
||||
params = {
|
||||
"includePaths": [],
|
||||
"excludePaths": [],
|
||||
"maxDepth": 1,
|
||||
"limit": 1,
|
||||
}
|
||||
mocked_firecrawl = {
|
||||
"id": "test",
|
||||
}
|
||||
mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl))
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
|
||||
assert job_id is not None
|
||||
assert isinstance(job_id, str)
|
||||
def _response(status_code: int, json_data: Mapping[str, Any] | None = None, text: str = "") -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
response.text = text
|
||||
response.json.return_value = json_data if json_data is not None else {}
|
||||
return response
|
||||
|
||||
|
||||
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
|
||||
api_key = "fc-"
|
||||
base_urls = ["https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"]
|
||||
for base in base_urls:
|
||||
app = FirecrawlApp(api_key=api_key, base_url=base)
|
||||
mock_post = mocker.patch("httpx.post")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"id": "job123"}
|
||||
mock_post.return_value = mock_resp
|
||||
app.crawl_url("https://example.com", params=None)
|
||||
called_url = mock_post.call_args[0][0]
|
||||
assert called_url == "https://custom.firecrawl.dev/v2/crawl"
|
||||
class TestFirecrawlApp:
|
||||
def test_init_requires_api_key_for_default_base_url(self):
|
||||
with pytest.raises(ValueError, match="No API key provided"):
|
||||
FirecrawlApp(api_key=None, base_url="https://api.firecrawl.dev")
|
||||
|
||||
def test_prepare_headers_and_build_url(self):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev/")
|
||||
|
||||
assert app._prepare_headers() == {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer fc-key",
|
||||
}
|
||||
assert app._build_url("/v2/crawl") == "https://custom.firecrawl.dev/v2/crawl"
|
||||
|
||||
def test_scrape_url_success(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch(
|
||||
"httpx.post",
|
||||
return_value=_response(
|
||||
200,
|
||||
{
|
||||
"data": {
|
||||
"metadata": {
|
||||
"title": "t",
|
||||
"description": "d",
|
||||
"sourceURL": "https://example.com",
|
||||
},
|
||||
"markdown": "body",
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
result = app.scrape_url("https://example.com", params={"onlyMainContent": False})
|
||||
|
||||
assert result == {
|
||||
"title": "t",
|
||||
"description": "d",
|
||||
"source_url": "https://example.com",
|
||||
"markdown": "body",
|
||||
}
|
||||
|
||||
def test_scrape_url_handles_known_error_status(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("boom"))
|
||||
mocker.patch("httpx.post", return_value=_response(429, {"error": "limit"}))
|
||||
|
||||
with pytest.raises(Exception, match="boom"):
|
||||
app.scrape_url("https://example.com")
|
||||
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_scrape_url_unknown_status_raises(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.post", return_value=_response(404, text="Not Found"))
|
||||
|
||||
with pytest.raises(Exception, match="Failed to scrape URL. Status code: 404"):
|
||||
app.scrape_url("https://example.com")
|
||||
|
||||
def test_crawl_url_success(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.post", return_value=_response(200, {"id": "job-1"}))
|
||||
|
||||
assert app.crawl_url("https://example.com") == "job-1"
|
||||
|
||||
def test_crawl_url_non_200_uses_error_handler(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl failed"))
|
||||
mocker.patch("httpx.post", return_value=_response(500, {"error": "server"}))
|
||||
|
||||
with pytest.raises(Exception, match="crawl failed"):
|
||||
app.crawl_url("https://example.com")
|
||||
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_map_success(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.post", return_value=_response(200, {"success": True, "links": ["a", "b"]}))
|
||||
|
||||
assert app.map("https://example.com") == {"success": True, "links": ["a", "b"]}
|
||||
|
||||
def test_map_known_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"}))
|
||||
|
||||
assert app.map("https://example.com") == {}
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_map_unknown_error_raises(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.post", return_value=_response(418, text="teapot"))
|
||||
|
||||
with pytest.raises(Exception, match="Failed to start map job. Status code: 418"):
|
||||
app.map("https://example.com")
|
||||
|
||||
def test_check_crawl_status_completed_with_data(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
payload = {
|
||||
"status": "completed",
|
||||
"total": 2,
|
||||
"completed": 2,
|
||||
"data": [
|
||||
{
|
||||
"metadata": {"title": "a", "description": "desc-a", "sourceURL": "https://a"},
|
||||
"markdown": "m-a",
|
||||
},
|
||||
{
|
||||
"metadata": {"title": "b", "description": "desc-b", "sourceURL": "https://b"},
|
||||
"markdown": "m-b",
|
||||
},
|
||||
{"metadata": {"title": "skip"}},
|
||||
],
|
||||
}
|
||||
mocker.patch("httpx.get", return_value=_response(200, payload))
|
||||
|
||||
save_calls: list[tuple[str, bytes]] = []
|
||||
delete_calls: list[str] = []
|
||||
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.exists.return_value = True
|
||||
mock_storage.delete.side_effect = lambda key: delete_calls.append(key)
|
||||
mock_storage.save.side_effect = lambda key, data: save_calls.append((key, data))
|
||||
mocker.patch.object(firecrawl_module, "storage", mock_storage)
|
||||
|
||||
result = app.check_crawl_status("job-42")
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["total"] == 2
|
||||
assert result["current"] == 2
|
||||
assert len(result["data"]) == 2
|
||||
assert delete_calls == ["website_files/job-42.txt"]
|
||||
assert len(save_calls) == 1
|
||||
assert save_calls[0][0] == "website_files/job-42.txt"
|
||||
|
||||
def test_check_crawl_status_completed_with_zero_total_raises(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": 0, "data": []}))
|
||||
|
||||
with pytest.raises(Exception, match="No page found"):
|
||||
app.check_crawl_status("job-1")
|
||||
|
||||
def test_check_crawl_status_non_completed(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
payload = {"status": "processing", "total": 5, "completed": 1, "data": []}
|
||||
mocker.patch("httpx.get", return_value=_response(200, payload))
|
||||
|
||||
assert app.check_crawl_status("job-1") == {
|
||||
"status": "processing",
|
||||
"total": 5,
|
||||
"current": 1,
|
||||
"data": [],
|
||||
}
|
||||
|
||||
def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mocker.patch("httpx.get", return_value=_response(500, {"error": "server"}))
|
||||
|
||||
assert app.check_crawl_status("job-1") == {}
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
payload = {
|
||||
"status": "completed",
|
||||
"total": 1,
|
||||
"completed": 1,
|
||||
"data": [{"metadata": {"title": "a", "sourceURL": "https://a"}, "markdown": "m-a"}],
|
||||
}
|
||||
mocker.patch("httpx.get", return_value=_response(200, payload))
|
||||
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.exists.return_value = False
|
||||
mock_storage.save.side_effect = RuntimeError("save failed")
|
||||
mocker.patch.object(firecrawl_module, "storage", mock_storage)
|
||||
|
||||
with pytest.raises(Exception, match="Error saving crawl data"):
|
||||
app.check_crawl_status("job-err")
|
||||
|
||||
def test_extract_common_fields_and_status_formatter(self):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
|
||||
fields = app._extract_common_fields(
|
||||
{"metadata": {"title": "t", "description": "d", "sourceURL": "u"}, "markdown": "m"}
|
||||
)
|
||||
assert fields == {"title": "t", "description": "d", "source_url": "u", "markdown": "m"}
|
||||
|
||||
status = app._format_crawl_status_response("completed", {"total": 1, "completed": 1}, [fields])
|
||||
assert status == {"status": "completed", "total": 1, "current": 1, "data": [fields]}
|
||||
|
||||
def test_post_and_get_request_retry_logic(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep")
|
||||
|
||||
resp_502_a = _response(502)
|
||||
resp_502_b = _response(502)
|
||||
resp_200 = _response(200)
|
||||
|
||||
mocker.patch("httpx.post", side_effect=[resp_502_a, resp_200])
|
||||
post_result = app._post_request("u", {"x": 1}, {"h": 1}, retries=3, backoff_factor=0.5)
|
||||
assert post_result is resp_200
|
||||
|
||||
mocker.patch("httpx.get", side_effect=[resp_502_b, _response(200)])
|
||||
get_result = app._get_request("u", {"h": 1}, retries=3, backoff_factor=0.25)
|
||||
assert get_result.status_code == 200
|
||||
|
||||
assert sleep_mock.call_count == 2
|
||||
|
||||
def test_post_and_get_request_return_last_502(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep")
|
||||
|
||||
last_post = _response(502)
|
||||
mocker.patch("httpx.post", side_effect=[_response(502), last_post])
|
||||
assert app._post_request("u", {}, {}, retries=2).status_code == 502
|
||||
|
||||
last_get = _response(502)
|
||||
mocker.patch("httpx.get", side_effect=[_response(502), last_get])
|
||||
assert app._get_request("u", {}, retries=2).status_code == 502
|
||||
|
||||
assert sleep_mock.call_count == 4
|
||||
|
||||
def test_handle_error_with_json_and_plain_text(self):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
|
||||
json_error = _response(400, {"message": "bad request"})
|
||||
with pytest.raises(Exception, match="bad request"):
|
||||
app._handle_error(json_error, "run task")
|
||||
|
||||
non_json = MagicMock()
|
||||
non_json.status_code = 400
|
||||
non_json.text = "plain error"
|
||||
non_json.json.side_effect = json.JSONDecodeError("bad", "x", 0)
|
||||
|
||||
with pytest.raises(Exception, match="plain error"):
|
||||
app._handle_error(non_json, "run task")
|
||||
|
||||
def test_search_success(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.post", return_value=_response(200, {"success": True, "data": [{"url": "x"}]}))
|
||||
assert app.search("python")["success"] is True
|
||||
|
||||
def test_search_warning_failure(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.post", return_value=_response(200, {"success": False, "warning": "bad search"}))
|
||||
with pytest.raises(Exception, match="bad search"):
|
||||
app.search("python")
|
||||
|
||||
def test_search_known_http_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"}))
|
||||
assert app.search("python") == {}
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_search_unknown_http_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.post", return_value=_response(418, text="teapot"))
|
||||
with pytest.raises(Exception, match="Failed to perform search. Status code: 418"):
|
||||
app.search("python")
|
||||
|
||||
|
||||
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
|
||||
api_key = "fc-"
|
||||
app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/")
|
||||
mock_post = mocker.patch("httpx.post")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
mock_resp.json.side_effect = Exception("Not JSON")
|
||||
mock_post.return_value = mock_resp
|
||||
class TestFirecrawlWebExtractor:
|
||||
def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data",
|
||||
return_value={
|
||||
"markdown": "crawl content",
|
||||
"source_url": "https://example.com",
|
||||
"description": "desc",
|
||||
"title": "title",
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
app.scrape_url("https://example.com")
|
||||
extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
|
||||
docs = extractor.extract()
|
||||
|
||||
# Should not raise a JSONDecodeError; current behavior reports status code only
|
||||
assert str(excinfo.value) == "Failed to scrape URL. Status code: 404"
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "crawl content"
|
||||
assert docs[0].metadata["source_url"] == "https://example.com"
|
||||
|
||||
def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
|
||||
assert extractor.extract() == []
|
||||
|
||||
def test_extract_scrape_mode_returns_document(self, mocker: MockerFixture):
|
||||
mock_scrape = mocker.patch(
|
||||
"core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_scrape_url_data",
|
||||
return_value={
|
||||
"markdown": "scrape content",
|
||||
"source_url": "https://example.com",
|
||||
"description": "desc",
|
||||
"title": "title",
|
||||
},
|
||||
)
|
||||
|
||||
extractor = FirecrawlWebExtractor(
|
||||
"https://example.com", "job-1", "tenant-1", mode="scrape", only_main_content=False
|
||||
)
|
||||
docs = extractor.extract()
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "scrape content"
|
||||
mock_scrape.assert_called_once_with("firecrawl", "https://example.com", "tenant-1", False)
|
||||
|
||||
def test_extract_unknown_mode_returns_empty(self):
|
||||
extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="unknown")
|
||||
assert extractor.extract() == []
|
||||
|
||||
@ -0,0 +1,95 @@
|
||||
import csv
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import core.rag.extractor.csv_extractor as csv_module
|
||||
from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
|
||||
|
||||
class _ManagedStringIO(io.StringIO):
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.close()
|
||||
return False
|
||||
|
||||
|
||||
class TestCSVExtractor:
|
||||
def test_extract_success_with_source_column(self, tmp_path):
|
||||
file_path = tmp_path / "data.csv"
|
||||
file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8")
|
||||
|
||||
extractor = CSVExtractor(str(file_path), source_column="id")
|
||||
docs = extractor.extract()
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "id: source-1;body: hello"
|
||||
assert docs[0].metadata == {"source": "source-1", "row": 0}
|
||||
|
||||
def test_extract_raises_when_source_column_missing(self, tmp_path):
|
||||
file_path = tmp_path / "data.csv"
|
||||
file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8")
|
||||
|
||||
extractor = CSVExtractor(str(file_path), source_column="missing_col")
|
||||
|
||||
with pytest.raises(ValueError, match="Source column 'missing_col' not found"):
|
||||
extractor.extract()
|
||||
|
||||
def test_extract_wraps_unicode_error_when_autodetect_disabled(self, monkeypatch):
|
||||
extractor = CSVExtractor("dummy.csv", autodetect_encoding=False)
|
||||
|
||||
def raise_decode(*args, **kwargs):
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")
|
||||
|
||||
monkeypatch.setattr("builtins.open", raise_decode)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Error loading dummy.csv"):
|
||||
extractor.extract()
|
||||
|
||||
def test_extract_autodetect_encoding_success(self, monkeypatch):
|
||||
extractor = CSVExtractor("dummy.csv", autodetect_encoding=True)
|
||||
attempted_encodings: list[str | None] = []
|
||||
|
||||
def fake_open(path, newline="", encoding=None):
|
||||
attempted_encodings.append(encoding)
|
||||
if encoding is None:
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")
|
||||
if encoding == "bad":
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")
|
||||
return _ManagedStringIO("id,body\nsource-1,hello\n")
|
||||
|
||||
monkeypatch.setattr("builtins.open", fake_open)
|
||||
monkeypatch.setattr(
|
||||
csv_module,
|
||||
"detect_file_encodings",
|
||||
lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")],
|
||||
)
|
||||
|
||||
docs = extractor.extract()
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "id: source-1;body: hello"
|
||||
assert attempted_encodings == [None, "bad", "utf-8"]
|
||||
|
||||
def test_extract_autodetect_encoding_all_attempts_fail_returns_empty(self, monkeypatch):
|
||||
extractor = CSVExtractor("dummy.csv", autodetect_encoding=True)
|
||||
|
||||
def always_raise(*args, **kwargs):
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error")
|
||||
|
||||
monkeypatch.setattr("builtins.open", always_raise)
|
||||
monkeypatch.setattr(csv_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="bad")])
|
||||
|
||||
assert extractor.extract() == []
|
||||
|
||||
def test_read_from_file_re_raises_csv_error(self, monkeypatch):
|
||||
extractor = CSVExtractor("dummy.csv")
|
||||
|
||||
monkeypatch.setattr(pd, "read_csv", lambda *args, **kwargs: (_ for _ in ()).throw(csv.Error("bad csv")))
|
||||
|
||||
with pytest.raises(csv.Error, match="bad csv"):
|
||||
extractor._read_from_file(io.StringIO("x"))
|
||||
117
api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py
Normal file
117
api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py
Normal file
@ -0,0 +1,117 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import core.rag.extractor.excel_extractor as excel_module
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
|
||||
|
||||
class _FakeCell:
|
||||
def __init__(self, value, hyperlink=None):
|
||||
self.value = value
|
||||
self.hyperlink = hyperlink
|
||||
|
||||
|
||||
class _FakeSheet:
|
||||
def __init__(self, header_rows, data_rows):
|
||||
self._header_rows = header_rows
|
||||
self._data_rows = data_rows
|
||||
|
||||
def iter_rows(self, min_row=1, max_row=None, max_col=None, values_only=False):
|
||||
if values_only:
|
||||
for row in self._header_rows:
|
||||
yield tuple(row)
|
||||
return
|
||||
|
||||
for row in self._data_rows:
|
||||
if max_col is not None:
|
||||
yield tuple(row[:max_col])
|
||||
else:
|
||||
yield tuple(row)
|
||||
|
||||
|
||||
class _FakeWorkbook:
|
||||
def __init__(self, sheets):
|
||||
self._sheets = sheets
|
||||
self.sheetnames = list(sheets.keys())
|
||||
self.closed = False
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._sheets[key]
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class TestExcelExtractor:
|
||||
def test_extract_xlsx_with_hyperlinks_and_sheet_skip(self, monkeypatch):
|
||||
sheet_with_data = _FakeSheet(
|
||||
header_rows=[("Name", "Link")],
|
||||
data_rows=[
|
||||
(_FakeCell("Alice"), _FakeCell("Doc", hyperlink=SimpleNamespace(target="https://example.com/doc"))),
|
||||
(_FakeCell(None), _FakeCell(123)),
|
||||
(_FakeCell(None), _FakeCell(None)),
|
||||
],
|
||||
)
|
||||
empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[])
|
||||
|
||||
workbook = _FakeWorkbook({"Data": sheet_with_data, "Empty": empty_sheet})
|
||||
monkeypatch.setattr(excel_module, "load_workbook", lambda *args, **kwargs: workbook)
|
||||
|
||||
extractor = ExcelExtractor("/tmp/sample.xlsx")
|
||||
docs = extractor.extract()
|
||||
|
||||
assert workbook.closed is True
|
||||
assert len(docs) == 2
|
||||
assert docs[0].page_content == '"Name":"Alice";"Link":"[Doc](https://example.com/doc)"'
|
||||
assert docs[1].page_content == '"Name":"";"Link":"123"'
|
||||
assert all(doc.metadata["source"] == "/tmp/sample.xlsx" for doc in docs)
|
||||
|
||||
def test_extract_xls_path(self, monkeypatch):
|
||||
class FakeExcelFile:
|
||||
sheet_names = ["Sheet1"]
|
||||
|
||||
def parse(self, sheet_name):
|
||||
assert sheet_name == "Sheet1"
|
||||
return pd.DataFrame([{"A": "x", "B": 1}, {"A": None, "B": None}])
|
||||
|
||||
monkeypatch.setattr(pd, "ExcelFile", lambda path, engine=None: FakeExcelFile())
|
||||
|
||||
extractor = ExcelExtractor("/tmp/sample.xls")
|
||||
docs = extractor.extract()
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == '"A":"x";"B":"1.0"'
|
||||
assert docs[0].metadata == {"source": "/tmp/sample.xls"}
|
||||
|
||||
def test_extract_unsupported_extension_raises(self):
|
||||
extractor = ExcelExtractor("/tmp/sample.txt")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported file extension"):
|
||||
extractor.extract()
|
||||
|
||||
def test_find_header_and_columns_prefers_first_row_with_two_columns(self):
|
||||
sheet = _FakeSheet(
|
||||
header_rows=[(None, None, None), ("A", "B", None), ("X", None, None)],
|
||||
data_rows=[],
|
||||
)
|
||||
extractor = ExcelExtractor("dummy.xlsx")
|
||||
|
||||
header_row_idx, column_map, max_col_idx = extractor._find_header_and_columns(sheet)
|
||||
|
||||
assert header_row_idx == 2
|
||||
assert column_map == {0: "A", 1: "B"}
|
||||
assert max_col_idx == 2
|
||||
|
||||
def test_find_header_and_columns_fallback_and_empty_case(self):
|
||||
extractor = ExcelExtractor("dummy.xlsx")
|
||||
|
||||
fallback_sheet = _FakeSheet(header_rows=[("Only", None), (None, "Second")], data_rows=[])
|
||||
row_idx, column_map, max_col_idx = extractor._find_header_and_columns(fallback_sheet)
|
||||
assert row_idx == 1
|
||||
assert column_map == {0: "Only"}
|
||||
assert max_col_idx == 1
|
||||
|
||||
empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[])
|
||||
assert extractor._find_header_and_columns(empty_sheet) == (0, {}, 0)
|
||||
@ -0,0 +1,272 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.extractor.extract_processor as processor_module
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class _ExtractorFactory:
|
||||
def __init__(self) -> None:
|
||||
self.calls = []
|
||||
|
||||
def make(self, name: str) -> type[object]:
|
||||
calls = self.calls
|
||||
|
||||
class DummyExtractor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
calls.append((name, args, kwargs))
|
||||
|
||||
def extract(self):
|
||||
return [Document(page_content=f"extracted-by-{name}")]
|
||||
|
||||
return DummyExtractor
|
||||
|
||||
|
||||
def _patch_all_extractors(monkeypatch) -> _ExtractorFactory:
|
||||
factory = _ExtractorFactory()
|
||||
|
||||
for cls_name in [
|
||||
"CSVExtractor",
|
||||
"ExcelExtractor",
|
||||
"FirecrawlWebExtractor",
|
||||
"HtmlExtractor",
|
||||
"JinaReaderWebExtractor",
|
||||
"MarkdownExtractor",
|
||||
"NotionExtractor",
|
||||
"PdfExtractor",
|
||||
"TextExtractor",
|
||||
"UnstructuredEmailExtractor",
|
||||
"UnstructuredEpubExtractor",
|
||||
"UnstructuredMarkdownExtractor",
|
||||
"UnstructuredMsgExtractor",
|
||||
"UnstructuredPPTExtractor",
|
||||
"UnstructuredPPTXExtractor",
|
||||
"UnstructuredWordExtractor",
|
||||
"UnstructuredXmlExtractor",
|
||||
"WaterCrawlWebExtractor",
|
||||
"WordExtractor",
|
||||
]:
|
||||
monkeypatch.setattr(processor_module, cls_name, factory.make(cls_name))
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
class TestExtractProcessorLoaders:
|
||||
def test_load_from_upload_file_return_docs_and_text(self, monkeypatch):
|
||||
monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
monkeypatch.setattr(
|
||||
ExtractProcessor,
|
||||
"extract",
|
||||
lambda extract_setting, is_automatic=False, file_path=None: [
|
||||
Document(page_content="doc-1"),
|
||||
Document(page_content="doc-2"),
|
||||
],
|
||||
)
|
||||
|
||||
upload_file = SimpleNamespace(key="file.txt")
|
||||
|
||||
docs = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=False)
|
||||
text = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=True)
|
||||
|
||||
assert len(docs) == 2
|
||||
assert text == "doc-1\ndoc-2"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("url", "headers", "expected_suffix"),
|
||||
[
|
||||
("https://example.com/file.txt", {"Content-Type": "text/plain"}, ".txt"),
|
||||
("https://example.com/no_suffix", {"Content-Type": "application/pdf"}, ".pdf"),
|
||||
(
|
||||
"https://example.com/no_suffix",
|
||||
{"Content-Disposition": 'attachment; filename="report.md"'},
|
||||
".md",
|
||||
),
|
||||
(
|
||||
"https://example.com/no_suffix",
|
||||
{"Content-Disposition": 'attachment; filename="report"'},
|
||||
"",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_load_from_url_builds_temp_file_with_correct_suffix(self, monkeypatch, url, headers, expected_suffix):
|
||||
response = SimpleNamespace(headers=headers, content=b"body")
|
||||
monkeypatch.setattr(processor_module.ssrf_proxy, "get", lambda *args, **kwargs: response)
|
||||
monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_extract(extract_setting, is_automatic=False, file_path=None):
|
||||
key = "file_path_docs" if "file_path_docs" not in captured else "file_path_text"
|
||||
captured[key] = file_path
|
||||
return [Document(page_content="u1"), Document(page_content="u2")]
|
||||
|
||||
monkeypatch.setattr(ExtractProcessor, "extract", fake_extract)
|
||||
|
||||
docs = ExtractProcessor.load_from_url(url, return_text=False)
|
||||
assert captured["file_path_docs"].endswith(expected_suffix)
|
||||
|
||||
text = ExtractProcessor.load_from_url(url, return_text=True)
|
||||
assert captured["file_path_text"].endswith(expected_suffix)
|
||||
|
||||
assert len(docs) == 2
|
||||
assert text == "u1\nu2"
|
||||
|
||||
|
||||
class TestExtractProcessorFileRouting:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _set_unstructured_config(self, monkeypatch):
|
||||
monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_URL", "https://unstructured")
|
||||
monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_KEY", "key")
|
||||
|
||||
def _run_extract_for_extension(self, monkeypatch, extension: str, etl_type: str, is_automatic: bool = False):
|
||||
factory = _patch_all_extractors(monkeypatch)
|
||||
monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", etl_type)
|
||||
|
||||
def fake_download(key: str, local_path: str):
|
||||
Path(local_path).write_text("content", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(processor_module.storage, "download", fake_download)
|
||||
monkeypatch.setattr(processor_module.tempfile, "_get_candidate_names", lambda: iter(["candidate-name"]))
|
||||
|
||||
setting = SimpleNamespace(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=SimpleNamespace(key=f"uploaded{extension}", tenant_id="tenant-1", created_by="user-1"),
|
||||
)
|
||||
|
||||
docs = ExtractProcessor.extract(setting, is_automatic=is_automatic)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content.startswith("extracted-by-")
|
||||
return factory.calls[-1][0], factory.calls[-1][1], factory.calls[-1][2]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("extension", "expected_extractor", "is_automatic"),
|
||||
[
|
||||
(".xlsx", "ExcelExtractor", False),
|
||||
(".xls", "ExcelExtractor", False),
|
||||
(".pdf", "PdfExtractor", False),
|
||||
(".md", "UnstructuredMarkdownExtractor", True),
|
||||
(".mdx", "MarkdownExtractor", False),
|
||||
(".htm", "HtmlExtractor", False),
|
||||
(".html", "HtmlExtractor", False),
|
||||
(".docx", "WordExtractor", False),
|
||||
(".doc", "UnstructuredWordExtractor", False),
|
||||
(".csv", "CSVExtractor", False),
|
||||
(".msg", "UnstructuredMsgExtractor", False),
|
||||
(".eml", "UnstructuredEmailExtractor", False),
|
||||
(".ppt", "UnstructuredPPTExtractor", False),
|
||||
(".pptx", "UnstructuredPPTXExtractor", False),
|
||||
(".xml", "UnstructuredXmlExtractor", False),
|
||||
(".epub", "UnstructuredEpubExtractor", False),
|
||||
(".txt", "TextExtractor", False),
|
||||
],
|
||||
)
|
||||
def test_extract_routes_file_extensions_for_unstructured_mode(
|
||||
self, monkeypatch, extension, expected_extractor, is_automatic
|
||||
):
|
||||
extractor_name, args, kwargs = self._run_extract_for_extension(
|
||||
monkeypatch, extension, etl_type="Unstructured", is_automatic=is_automatic
|
||||
)
|
||||
|
||||
assert extractor_name == expected_extractor
|
||||
assert args
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("extension", "expected_extractor"),
|
||||
[
|
||||
(".xlsx", "ExcelExtractor"),
|
||||
(".pdf", "PdfExtractor"),
|
||||
(".markdown", "MarkdownExtractor"),
|
||||
(".html", "HtmlExtractor"),
|
||||
(".docx", "WordExtractor"),
|
||||
(".csv", "CSVExtractor"),
|
||||
(".epub", "UnstructuredEpubExtractor"),
|
||||
(".txt", "TextExtractor"),
|
||||
],
|
||||
)
|
||||
def test_extract_routes_file_extensions_for_default_mode(self, monkeypatch, extension, expected_extractor):
|
||||
extractor_name, _, _ = self._run_extract_for_extension(monkeypatch, extension, etl_type="SelfHosted")
|
||||
|
||||
assert extractor_name == expected_extractor
|
||||
|
||||
def test_extract_requires_upload_file_when_file_path_not_provided(self):
|
||||
setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
|
||||
|
||||
with pytest.raises(AssertionError, match="upload_file is required"):
|
||||
ExtractProcessor.extract(setting)
|
||||
|
||||
|
||||
class TestExtractProcessorDatasourceRouting:
|
||||
def test_extract_routes_notion_datasource(self, monkeypatch):
|
||||
factory = _patch_all_extractors(monkeypatch)
|
||||
|
||||
notion_info = SimpleNamespace(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
document="doc",
|
||||
tenant_id="tenant",
|
||||
credential_id="cred",
|
||||
)
|
||||
setting = SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=notion_info)
|
||||
|
||||
docs = ExtractProcessor.extract(setting)
|
||||
|
||||
assert docs[0].page_content == "extracted-by-NotionExtractor"
|
||||
assert factory.calls[-1][0] == "NotionExtractor"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "expected"),
|
||||
[
|
||||
("firecrawl", "FirecrawlWebExtractor"),
|
||||
("watercrawl", "WaterCrawlWebExtractor"),
|
||||
("jinareader", "JinaReaderWebExtractor"),
|
||||
],
|
||||
)
|
||||
def test_extract_routes_website_datasource_providers(self, monkeypatch, provider: str, expected: str):
|
||||
factory = _patch_all_extractors(monkeypatch)
|
||||
|
||||
website_info = SimpleNamespace(
|
||||
provider=provider,
|
||||
url="https://example.com",
|
||||
job_id="job",
|
||||
tenant_id="tenant",
|
||||
mode="crawl",
|
||||
only_main_content=True,
|
||||
)
|
||||
setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=website_info)
|
||||
|
||||
docs = ExtractProcessor.extract(setting)
|
||||
assert docs[0].page_content == f"extracted-by-{expected}"
|
||||
assert factory.calls[-1][0] == expected
|
||||
|
||||
def test_extract_unsupported_website_provider(self):
|
||||
bad_provider = SimpleNamespace(
|
||||
provider="unknown",
|
||||
url="https://example.com",
|
||||
job_id="job",
|
||||
tenant_id="tenant",
|
||||
mode="crawl",
|
||||
only_main_content=True,
|
||||
)
|
||||
setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=bad_provider)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported website provider"):
|
||||
ExtractProcessor.extract(setting)
|
||||
|
||||
def test_extract_unsupported_datasource_type(self):
|
||||
with pytest.raises(ValueError, match="Unsupported datasource type"):
|
||||
ExtractProcessor.extract(SimpleNamespace(datasource_type="unknown"))
|
||||
|
||||
def test_extract_requires_notion_info(self):
|
||||
with pytest.raises(AssertionError, match="notion_info is required"):
|
||||
ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=None))
|
||||
|
||||
def test_extract_requires_website_info(self):
|
||||
with pytest.raises(AssertionError, match="website_info is required"):
|
||||
ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=None))
|
||||
@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
|
||||
|
||||
class _CallsBaseExtractor(BaseExtractor):
|
||||
def extract(self):
|
||||
return super().extract()
|
||||
|
||||
|
||||
class _ConcreteExtractor(BaseExtractor):
|
||||
def extract(self):
|
||||
return ["ok"]
|
||||
|
||||
|
||||
class TestBaseExtractor:
|
||||
def test_extract_default_raises_not_implemented(self):
|
||||
extractor = _CallsBaseExtractor()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
extractor.extract()
|
||||
|
||||
def test_concrete_extractor_can_override(self):
|
||||
extractor = _ConcreteExtractor()
|
||||
|
||||
assert extractor.extract() == ["ok"]
|
||||
@ -1,10 +1,55 @@
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.rag.extractor.helpers import FileEncoding, detect_file_encodings
|
||||
import pytest
|
||||
|
||||
from core.rag.extractor import helpers
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
|
||||
|
||||
def test_detect_file_encodings() -> None:
|
||||
with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp:
|
||||
temp.write("Shared data")
|
||||
temp_path = temp.name
|
||||
assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")]
|
||||
class TestHelpers:
|
||||
def test_detect_file_encodings(self) -> None:
|
||||
with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp:
|
||||
temp.write("Shared data")
|
||||
temp.flush()
|
||||
temp_path = temp.name
|
||||
encodings = detect_file_encodings(temp_path)
|
||||
|
||||
assert len(encodings) == 1
|
||||
assert encodings[0].encoding in {"utf_8", "ascii"}
|
||||
assert encodings[0].confidence == 0.0
|
||||
# Assert the language field for full coverage
|
||||
assert encodings[0].language is not None
|
||||
|
||||
def test_detect_file_encodings_timeout(self, monkeypatch):
|
||||
class FakeFuture:
|
||||
def result(self, timeout=None):
|
||||
raise helpers.concurrent.futures.TimeoutError()
|
||||
|
||||
class FakeExecutor:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def submit(self, fn, file_path):
|
||||
return FakeFuture()
|
||||
|
||||
monkeypatch.setattr(helpers.concurrent.futures, "ThreadPoolExecutor", lambda: FakeExecutor())
|
||||
|
||||
with pytest.raises(TimeoutError, match="Timeout reached while detecting encoding"):
|
||||
detect_file_encodings("file.txt", timeout=1)
|
||||
|
||||
def test_detect_file_encodings_raises_when_encoding_not_detected(self, monkeypatch):
|
||||
class FakeResult:
|
||||
encoding = None
|
||||
coherence = 0.0
|
||||
language = None
|
||||
|
||||
monkeypatch.setattr(
|
||||
helpers.charset_normalizer, "from_path", lambda _: SimpleNamespace(best=lambda: FakeResult())
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Could not detect encoding"):
|
||||
detect_file_encodings("file.txt")
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
|
||||
|
||||
class TestHtmlExtractor:
|
||||
def test_extract_returns_text_content(self, tmp_path):
|
||||
file_path = tmp_path / "sample.html"
|
||||
file_path.write_text("<html><body><h1>Title</h1><p>Hello</p></body></html>", encoding="utf-8")
|
||||
|
||||
extractor = HtmlExtractor(str(file_path))
|
||||
docs = extractor.extract()
|
||||
|
||||
assert len(docs) == 1
|
||||
assert "".join(docs[0].page_content.split()) == "TitleHello"
|
||||
|
||||
def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path):
|
||||
file_path = tmp_path / "sample.html"
|
||||
file_path.write_text("<html><body> \n </body></html>", encoding="utf-8")
|
||||
|
||||
extractor = HtmlExtractor(str(file_path))
|
||||
|
||||
assert extractor._load_as_text() == ""
|
||||
@ -0,0 +1,47 @@
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
|
||||
|
||||
|
||||
class TestJinaReaderWebExtractor:
|
||||
def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
|
||||
return_value={
|
||||
"content": "markdown-content",
|
||||
"url": "https://example.com",
|
||||
"description": "desc",
|
||||
"title": "title",
|
||||
},
|
||||
)
|
||||
|
||||
extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
|
||||
docs = extractor.extract()
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "markdown-content"
|
||||
assert docs[0].metadata == {
|
||||
"source_url": "https://example.com",
|
||||
"description": "desc",
|
||||
"title": "title",
|
||||
}
|
||||
|
||||
def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")
|
||||
|
||||
assert extractor.extract() == []
|
||||
|
||||
def test_extract_non_crawl_mode_returns_empty(self, mocker: MockerFixture):
|
||||
mock_get_crawl = mocker.patch(
|
||||
"core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data",
|
||||
return_value={"content": "unused"},
|
||||
)
|
||||
extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape")
|
||||
|
||||
assert extractor.extract() == []
|
||||
mock_get_crawl.assert_not_called()
|
||||
@ -1,8 +1,15 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.extractor.markdown_extractor as markdown_module
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
|
||||
|
||||
def test_markdown_to_tups():
|
||||
markdown = """
|
||||
class TestMarkdownExtractor:
|
||||
def test_markdown_to_tups(self):
|
||||
markdown = """
|
||||
this is some text without header
|
||||
|
||||
# title 1
|
||||
@ -11,12 +18,113 @@ this is balabala text
|
||||
## title 2
|
||||
this is more specific text.
|
||||
"""
|
||||
extractor = MarkdownExtractor(file_path="dummy_path")
|
||||
updated_output = extractor.markdown_to_tups(markdown)
|
||||
assert len(updated_output) == 3
|
||||
key, header_value = updated_output[0]
|
||||
assert key == None
|
||||
assert header_value.strip() == "this is some text without header"
|
||||
title_1, value = updated_output[1]
|
||||
assert title_1.strip() == "title 1"
|
||||
assert value.strip() == "this is balabala text"
|
||||
extractor = MarkdownExtractor(file_path="dummy_path")
|
||||
updated_output = extractor.markdown_to_tups(markdown)
|
||||
|
||||
assert len(updated_output) == 3
|
||||
key, header_value = updated_output[0]
|
||||
assert key is None
|
||||
assert header_value.strip() == "this is some text without header"
|
||||
|
||||
title_1, value = updated_output[1]
|
||||
assert title_1.strip() == "title 1"
|
||||
assert value.strip() == "this is balabala text"
|
||||
|
||||
def test_markdown_to_tups_keeps_code_block_headers_literal(self):
|
||||
markdown = """# Header
|
||||
before
|
||||
```python
|
||||
# this is not a heading
|
||||
print('x')
|
||||
```
|
||||
after
|
||||
"""
|
||||
extractor = MarkdownExtractor(file_path="dummy_path")
|
||||
|
||||
tups = extractor.markdown_to_tups(markdown)
|
||||
|
||||
assert len(tups) == 2
|
||||
assert tups[1][0] == "Header"
|
||||
assert "# this is not a heading" in tups[1][1]
|
||||
|
||||
def test_remove_images_and_hyperlinks(self):
|
||||
extractor = MarkdownExtractor(file_path="dummy_path")
|
||||
|
||||
with_images = "before ![[image.png]] after"
|
||||
with_links = "[OpenAI](https://openai.com)"
|
||||
|
||||
assert extractor.remove_images(with_images) == "before after"
|
||||
assert extractor.remove_hyperlinks(with_links) == "OpenAI"
|
||||
|
||||
def test_parse_tups_reads_file_and_applies_options(self, tmp_path):
|
||||
markdown_file = tmp_path / "doc.md"
|
||||
markdown_file.write_text("# Header\nText with [link](https://example.com) and ![[img.png]]", encoding="utf-8")
|
||||
|
||||
extractor = MarkdownExtractor(
|
||||
file_path=str(markdown_file),
|
||||
remove_hyperlinks=True,
|
||||
remove_images=True,
|
||||
autodetect_encoding=False,
|
||||
)
|
||||
|
||||
tups = extractor.parse_tups(str(markdown_file))
|
||||
|
||||
assert len(tups) == 2
|
||||
assert tups[1][0] == "Header"
|
||||
assert "[link]" not in tups[1][1]
|
||||
assert "img.png" not in tups[1][1]
|
||||
|
||||
def test_parse_tups_autodetects_encoding_after_decode_error(self, monkeypatch):
|
||||
extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=True)
|
||||
|
||||
calls: list[str | None] = []
|
||||
|
||||
def fake_read_text(self, encoding=None):
|
||||
calls.append(encoding)
|
||||
if encoding is None:
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail")
|
||||
if encoding == "bad-encoding":
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail")
|
||||
return "# H\ncontent"
|
||||
|
||||
monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True)
|
||||
monkeypatch.setattr(
|
||||
markdown_module,
|
||||
"detect_file_encodings",
|
||||
lambda _: [SimpleNamespace(encoding="bad-encoding"), SimpleNamespace(encoding="utf-8")],
|
||||
)
|
||||
|
||||
tups = extractor.parse_tups("dummy_path")
|
||||
|
||||
assert len(tups) == 2
|
||||
assert calls == [None, "bad-encoding", "utf-8"]
|
||||
|
||||
def test_parse_tups_decode_error_with_autodetect_disabled_raises(self, monkeypatch):
|
||||
extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=False)
|
||||
|
||||
def raise_decode(self, encoding=None):
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail")
|
||||
|
||||
monkeypatch.setattr(Path, "read_text", raise_decode, raising=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Error loading dummy_path"):
|
||||
extractor.parse_tups("dummy_path")
|
||||
|
||||
def test_parse_tups_other_exceptions_are_wrapped(self, monkeypatch):
|
||||
extractor = MarkdownExtractor(file_path="dummy_path")
|
||||
|
||||
def raise_other(self, encoding=None):
|
||||
raise OSError("disk error")
|
||||
|
||||
monkeypatch.setattr(Path, "read_text", raise_other, raising=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Error loading dummy_path"):
|
||||
extractor.parse_tups("dummy_path")
|
||||
|
||||
def test_extract_builds_documents_for_header_and_non_header(self, monkeypatch):
|
||||
extractor = MarkdownExtractor(file_path="dummy_path")
|
||||
monkeypatch.setattr(extractor, "parse_tups", lambda _: [(None, "plain"), ("Header", "value")])
|
||||
|
||||
docs = extractor.extract()
|
||||
|
||||
assert [doc.page_content for doc in docs] == ["plain", "\n\nHeader\nvalue"]
|
||||
|
||||
@ -1,93 +1,499 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor import notion_extractor
|
||||
|
||||
user_id = "user1"
|
||||
database_id = "database1"
|
||||
page_id = "page1"
|
||||
|
||||
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
|
||||
)
|
||||
|
||||
|
||||
def _generate_page(page_title: str):
|
||||
return {
|
||||
"object": "page",
|
||||
"id": page_id,
|
||||
"properties": {
|
||||
"Page": {
|
||||
"type": "title",
|
||||
"title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _generate_block(block_id: str, block_type: str, block_text: str):
|
||||
return {
|
||||
"object": "block",
|
||||
"id": block_id,
|
||||
"parent": {"type": "page_id", "page_id": page_id},
|
||||
"type": block_type,
|
||||
"has_children": False,
|
||||
block_type: {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": block_text},
|
||||
"plain_text": block_text,
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _mock_response(data):
|
||||
def _mock_response(data, status_code: int = 200, text: str = ""):
|
||||
response = mock.Mock()
|
||||
response.status_code = 200
|
||||
response.status_code = status_code
|
||||
response.text = text
|
||||
response.json.return_value = data
|
||||
return response
|
||||
|
||||
|
||||
def _remove_multiple_new_lines(text):
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
return text.strip()
|
||||
class TestNotionExtractorInitAndPublicMethods:
|
||||
def test_init_with_explicit_token(self):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
assert extractor._notion_access_token == "token"
|
||||
|
||||
def test_init_falls_back_to_env_token_when_credential_lookup_fails(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
notion_extractor.NotionExtractor,
|
||||
"_get_access_token",
|
||||
classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))),
|
||||
)
|
||||
monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", "env-token", raising=False)
|
||||
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
credential_id="cred",
|
||||
)
|
||||
|
||||
assert extractor._notion_access_token == "env-token"
|
||||
|
||||
def test_init_raises_if_no_credential_and_no_env_token(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
notion_extractor.NotionExtractor,
|
||||
"_get_access_token",
|
||||
classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))),
|
||||
)
|
||||
monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", None, raising=False)
|
||||
|
||||
with pytest.raises(ValueError, match="Must specify `integration_token`"):
|
||||
notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
credential_id="cred",
|
||||
)
|
||||
|
||||
def test_extract_updates_last_edited_and_loads_documents(self, monkeypatch):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
update_mock = mock.Mock()
|
||||
load_mock = mock.Mock(return_value=[SimpleNamespace(page_content="doc")])
|
||||
monkeypatch.setattr(extractor, "update_last_edited_time", update_mock)
|
||||
monkeypatch.setattr(extractor, "_load_data_as_documents", load_mock)
|
||||
|
||||
docs = extractor.extract()
|
||||
|
||||
update_mock.assert_called_once_with(None)
|
||||
load_mock.assert_called_once_with("obj", "page")
|
||||
assert len(docs) == 1
|
||||
|
||||
def test_load_data_as_documents_page_database_and_invalid(self, monkeypatch):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(extractor, "_get_notion_block_data", lambda _: ["line1", "line2"])
|
||||
page_docs = extractor._load_data_as_documents("page-id", "page")
|
||||
assert page_docs[0].page_content == "line1\nline2"
|
||||
|
||||
monkeypatch.setattr(extractor, "_get_notion_database_data", lambda _: [SimpleNamespace(page_content="db")])
|
||||
db_docs = extractor._load_data_as_documents("db-id", "database")
|
||||
assert db_docs[0].page_content == "db"
|
||||
|
||||
with pytest.raises(ValueError, match="notion page type not supported"):
|
||||
extractor._load_data_as_documents("obj", "unsupported")
|
||||
|
||||
|
||||
def test_notion_page(mocker: MockerFixture):
|
||||
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
|
||||
mocked_notion_page = {
|
||||
"object": "list",
|
||||
"results": [
|
||||
_generate_block("b1", "heading_1", texts[0]),
|
||||
_generate_block("b2", "heading_2", texts[1]),
|
||||
_generate_block("b3", "paragraph", texts[2]),
|
||||
_generate_block("b4", "heading_3", texts[3]),
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))
|
||||
class TestNotionDatabase:
|
||||
def test_get_notion_database_data_parses_property_types_and_pagination(self, mocker: MockerFixture):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="database",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
page_docs = extractor._load_data_as_documents(page_id, "page")
|
||||
assert len(page_docs) == 1
|
||||
content = _remove_multiple_new_lines(page_docs[0].page_content)
|
||||
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
|
||||
first_page = {
|
||||
"results": [
|
||||
{
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "multi_select",
|
||||
"multi_select": [{"name": "A"}, {"name": "B"}],
|
||||
},
|
||||
"title_prop": {"type": "title", "title": [{"plain_text": "Title"}]},
|
||||
"empty_title": {"type": "title", "title": []},
|
||||
"rich": {"type": "rich_text", "rich_text": [{"plain_text": "RichText"}]},
|
||||
"empty_rich": {"type": "rich_text", "rich_text": []},
|
||||
"select_prop": {"type": "select", "select": {"name": "Selected"}},
|
||||
"empty_select": {"type": "select", "select": None},
|
||||
"status_prop": {"type": "status", "status": {"name": "Open"}},
|
||||
"empty_status": {"type": "status", "status": None},
|
||||
"number_prop": {"type": "number", "number": 10},
|
||||
"dict_prop": {"type": "date", "date": {"start": "2024-01-01", "end": None}},
|
||||
},
|
||||
"url": "https://notion.so/page-1",
|
||||
}
|
||||
],
|
||||
"has_more": True,
|
||||
"next_cursor": "cursor-2",
|
||||
}
|
||||
second_page = {"results": [], "has_more": False, "next_cursor": None}
|
||||
|
||||
mock_post = mocker.patch("httpx.post", side_effect=[_mock_response(first_page), _mock_response(second_page)])
|
||||
|
||||
docs = extractor._get_notion_database_data("db-1", query_dict={"filter": {"x": 1}})
|
||||
|
||||
assert len(docs) == 1
|
||||
content = docs[0].page_content
|
||||
assert "tags:['A', 'B']" in content
|
||||
assert "title_prop:Title" in content
|
||||
assert "rich:RichText" in content
|
||||
assert "number_prop:10" in content
|
||||
assert "dict_prop:start:2024-01-01" in content
|
||||
assert "Row Page URL:https://notion.so/page-1" in content
|
||||
assert mock_post.call_count == 2
|
||||
|
||||
def test_get_notion_database_data_handles_missing_results_and_empty_content(self, mocker: MockerFixture):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="database",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
mocker.patch("httpx.post", return_value=_mock_response({"results": None}))
|
||||
assert extractor._get_notion_database_data("db-1") == []
|
||||
|
||||
def test_get_notion_database_data_requires_access_token(self):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="database",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
extractor._notion_access_token = None
|
||||
|
||||
with pytest.raises(AssertionError, match="Notion access token is required"):
|
||||
extractor._get_notion_database_data("db-1")
|
||||
|
||||
|
||||
def test_notion_database(mocker: MockerFixture):
|
||||
page_title_list = ["page1", "page2", "page3"]
|
||||
mocked_notion_database = {
|
||||
"object": "list",
|
||||
"results": [_generate_page(i) for i in page_title_list],
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
|
||||
database_docs = extractor._load_data_as_documents(database_id, "database")
|
||||
assert len(database_docs) == 1
|
||||
content = _remove_multiple_new_lines(database_docs[0].page_content)
|
||||
assert content == "\n".join([f"Page:{i}" for i in page_title_list])
|
||||
class TestNotionBlocks:
|
||||
def test_get_notion_block_data_success_with_table_headings_children_and_pagination(self, mocker: MockerFixture):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
first_response = {
|
||||
"results": [
|
||||
{"type": "table", "id": "tbl-1", "has_children": False, "table": {}},
|
||||
{
|
||||
"type": "heading_1",
|
||||
"id": "h1",
|
||||
"has_children": False,
|
||||
"heading_1": {"rich_text": [{"text": {"content": "Heading"}}]},
|
||||
},
|
||||
{
|
||||
"type": "paragraph",
|
||||
"id": "p1",
|
||||
"has_children": True,
|
||||
"paragraph": {"rich_text": [{"text": {"content": "Paragraph"}}]},
|
||||
},
|
||||
{
|
||||
"type": "child_page",
|
||||
"id": "cp1",
|
||||
"has_children": True,
|
||||
"child_page": {"rich_text": []},
|
||||
},
|
||||
],
|
||||
"next_cursor": "cursor-2",
|
||||
}
|
||||
second_response = {
|
||||
"results": [
|
||||
{
|
||||
"type": "heading_2",
|
||||
"id": "h2",
|
||||
"has_children": False,
|
||||
"heading_2": {"rich_text": [{"text": {"content": "SubHeading"}}]},
|
||||
}
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
mocker.patch("httpx.request", side_effect=[_mock_response(first_response), _mock_response(second_response)])
|
||||
mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE")
|
||||
mocker.patch.object(extractor, "_read_block", return_value="CHILD")
|
||||
|
||||
lines = extractor._get_notion_block_data("page-1")
|
||||
|
||||
assert lines[0] == "TABLE\n\n"
|
||||
assert "# Heading" in lines[1]
|
||||
assert "Paragraph\nCHILD\n\n" in lines[2]
|
||||
assert "## SubHeading" in lines[-1]
|
||||
|
||||
def test_get_notion_block_data_handles_http_error_and_invalid_payload(self, mocker: MockerFixture):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
mocker.patch("httpx.request", side_effect=httpx.HTTPError("network"))
|
||||
with pytest.raises(ValueError, match="Error fetching Notion block data"):
|
||||
extractor._get_notion_block_data("page-1")
|
||||
|
||||
mocker.patch("httpx.request", return_value=_mock_response({"bad": "payload"}, status_code=200))
|
||||
with pytest.raises(ValueError, match="Error fetching Notion block data"):
|
||||
extractor._get_notion_block_data("page-1")
|
||||
|
||||
mocker.patch("httpx.request", return_value=_mock_response({"results": []}, status_code=500, text="boom"))
|
||||
with pytest.raises(ValueError, match="Error fetching Notion block data: boom"):
|
||||
extractor._get_notion_block_data("page-1")
|
||||
|
||||
def test_read_block_supports_heading_table_and_recursion(self, mocker: MockerFixture):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
root_payload = {
|
||||
"results": [
|
||||
{
|
||||
"type": "heading_2",
|
||||
"id": "h2",
|
||||
"has_children": False,
|
||||
"heading_2": {"rich_text": [{"text": {"content": "Root"}}]},
|
||||
},
|
||||
{
|
||||
"type": "paragraph",
|
||||
"id": "child-block",
|
||||
"has_children": True,
|
||||
"paragraph": {"rich_text": [{"text": {"content": "Parent"}}]},
|
||||
},
|
||||
{"type": "table", "id": "tbl-1", "has_children": False, "table": {}},
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
child_payload = {
|
||||
"results": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"id": "leaf",
|
||||
"has_children": False,
|
||||
"paragraph": {"rich_text": [{"text": {"content": "Child"}}]},
|
||||
}
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
mocker.patch("httpx.request", side_effect=[_mock_response(root_payload), _mock_response(child_payload)])
|
||||
mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE-MD")
|
||||
|
||||
content = extractor._read_block("root")
|
||||
|
||||
assert "## Root" in content
|
||||
assert "Parent" in content
|
||||
assert "Child" in content
|
||||
assert "TABLE-MD" in content
|
||||
|
||||
def test_read_block_breaks_on_missing_results(self, mocker: MockerFixture):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
mocker.patch("httpx.request", return_value=_mock_response({"results": None, "next_cursor": None}))
|
||||
|
||||
assert extractor._read_block("root") == ""
|
||||
|
||||
def test_read_table_rows_formats_markdown_with_pagination(self, mocker: MockerFixture):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
page_one = {
|
||||
"results": [
|
||||
{
|
||||
"table_row": {
|
||||
"cells": [
|
||||
[{"text": {"content": "H1"}}],
|
||||
[{"text": {"content": "H2"}}],
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"table_row": {
|
||||
"cells": [
|
||||
[{"text": {"content": "R1C1"}}],
|
||||
[{"text": {"content": "R1C2"}}],
|
||||
]
|
||||
}
|
||||
},
|
||||
],
|
||||
"next_cursor": "next",
|
||||
}
|
||||
page_two = {
|
||||
"results": [
|
||||
{
|
||||
"table_row": {
|
||||
"cells": [
|
||||
[{"text": {"content": "H1"}}],
|
||||
[],
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"table_row": {
|
||||
"cells": [
|
||||
[{"text": {"content": "R2C1"}}],
|
||||
[{"text": {"content": "R2C2"}}],
|
||||
]
|
||||
}
|
||||
},
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
mocker.patch("httpx.request", side_effect=[_mock_response(page_one), _mock_response(page_two)])
|
||||
|
||||
markdown = extractor._read_table_rows("tbl-1")
|
||||
|
||||
assert "| H1 | H2 |" in markdown
|
||||
assert "| R1C1 | R1C2 |" in markdown
|
||||
assert "| H1 | |" in markdown
|
||||
assert "| R2C1 | R2C2 |" in markdown
|
||||
|
||||
|
||||
class TestNotionMetadataAndCredentialMethods:
|
||||
def test_update_last_edited_time_no_document_model(self):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
assert extractor.update_last_edited_time(None) is None
|
||||
|
||||
def test_update_last_edited_time_updates_document_and_commits(self, monkeypatch):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
|
||||
class FakeDocumentModel:
|
||||
data_source_info = "data_source_info"
|
||||
|
||||
update_calls = []
|
||||
|
||||
class FakeQuery:
|
||||
def filter_by(self, **kwargs):
|
||||
return self
|
||||
|
||||
def update(self, payload):
|
||||
update_calls.append(payload)
|
||||
|
||||
class FakeSession:
|
||||
committed = False
|
||||
|
||||
def query(self, model):
|
||||
assert model is FakeDocumentModel
|
||||
return FakeQuery()
|
||||
|
||||
def commit(self):
|
||||
self.committed = True
|
||||
|
||||
fake_db = SimpleNamespace(session=FakeSession())
|
||||
monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
|
||||
monkeypatch.setattr(notion_extractor, "db", fake_db)
|
||||
monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")
|
||||
|
||||
doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
|
||||
extractor.update_last_edited_time(doc_model)
|
||||
|
||||
assert update_calls
|
||||
assert fake_db.session.committed is True
|
||||
|
||||
def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):
|
||||
extractor_page = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="page-id",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
request_mock = mocker.patch(
|
||||
"httpx.request", return_value=_mock_response({"last_edited_time": "2025-05-01T00:00:00.000Z"})
|
||||
)
|
||||
|
||||
assert extractor_page.get_notion_last_edited_time() == "2025-05-01T00:00:00.000Z"
|
||||
assert "pages/page-id" in request_mock.call_args[0][1]
|
||||
|
||||
extractor_db = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="db-id",
|
||||
notion_page_type="database",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
request_mock = mocker.patch(
|
||||
"httpx.request", return_value=_mock_response({"last_edited_time": "2025-06-01T00:00:00.000Z"})
|
||||
)
|
||||
|
||||
assert extractor_db.get_notion_last_edited_time() == "2025-06-01T00:00:00.000Z"
|
||||
assert "databases/db-id" in request_mock.call_args[0][1]
|
||||
|
||||
def test_get_notion_last_edited_time_requires_access_token(self):
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id="ws",
|
||||
notion_obj_id="obj",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant",
|
||||
notion_access_token="token",
|
||||
)
|
||||
extractor._notion_access_token = None
|
||||
|
||||
with pytest.raises(AssertionError, match="Notion access token is required"):
|
||||
extractor.get_notion_last_edited_time()
|
||||
|
||||
def test_get_access_token_success_and_errors(self, monkeypatch):
|
||||
with pytest.raises(Exception, match="No credential id found"):
|
||||
notion_extractor.NotionExtractor._get_access_token("tenant", None)
|
||||
|
||||
class FakeProviderServiceMissing:
|
||||
def get_datasource_credentials(self, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceMissing)
|
||||
with pytest.raises(Exception, match="No notion credential found"):
|
||||
notion_extractor.NotionExtractor._get_access_token("tenant", "cred")
|
||||
|
||||
class FakeProviderServiceFound:
|
||||
def get_datasource_credentials(self, **kwargs):
|
||||
return {"integration_secret": "token-from-credential"}
|
||||
|
||||
monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceFound)
|
||||
|
||||
assert notion_extractor.NotionExtractor._get_access_token("tenant", "cred") == "token-from-credential"
|
||||
|
||||
@ -0,0 +1,79 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.extractor.text_extractor as text_module
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
|
||||
|
||||
class TestTextExtractor:
|
||||
def test_extract_success(self, tmp_path):
|
||||
file_path = tmp_path / "data.txt"
|
||||
file_path.write_text("hello world", encoding="utf-8")
|
||||
|
||||
extractor = TextExtractor(str(file_path))
|
||||
docs = extractor.extract()
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "hello world"
|
||||
assert docs[0].metadata == {"source": str(file_path)}
|
||||
|
||||
def test_extract_autodetect_success_after_decode_error(self, monkeypatch):
|
||||
extractor = TextExtractor("dummy.txt", autodetect_encoding=True)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_read_text(self, encoding=None):
|
||||
calls.append(encoding)
|
||||
if encoding is None:
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
|
||||
if encoding == "bad":
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
|
||||
return "decoded text"
|
||||
|
||||
monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True)
|
||||
monkeypatch.setattr(
|
||||
text_module,
|
||||
"detect_file_encodings",
|
||||
lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")],
|
||||
)
|
||||
|
||||
docs = extractor.extract()
|
||||
|
||||
assert docs[0].page_content == "decoded text"
|
||||
assert calls == [None, "bad", "utf-8"]
|
||||
|
||||
def test_extract_autodetect_all_fail_raises_runtime_error(self, monkeypatch):
|
||||
extractor = TextExtractor("dummy.txt", autodetect_encoding=True)
|
||||
|
||||
def always_decode_error(self, encoding=None):
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
|
||||
|
||||
monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True)
|
||||
monkeypatch.setattr(text_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="latin-1")])
|
||||
|
||||
with pytest.raises(RuntimeError, match="all detected encodings failed"):
|
||||
extractor.extract()
|
||||
|
||||
def test_extract_decode_error_without_autodetect_raises_runtime_error(self, monkeypatch):
|
||||
extractor = TextExtractor("dummy.txt", autodetect_encoding=False)
|
||||
|
||||
def always_decode_error(self, encoding=None):
|
||||
raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode")
|
||||
|
||||
monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="specified encoding failed"):
|
||||
extractor.extract()
|
||||
|
||||
def test_extract_wraps_non_decode_exceptions(self, monkeypatch):
|
||||
extractor = TextExtractor("dummy.txt")
|
||||
|
||||
def raise_other(self, encoding=None):
|
||||
raise OSError("io error")
|
||||
|
||||
monkeypatch.setattr(Path, "read_text", raise_other, raising=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Error loading dummy.txt"):
|
||||
extractor.extract()
|
||||
@ -3,9 +3,12 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
from collections import UserDict
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from docx import Document
|
||||
from docx.oxml import OxmlElement
|
||||
from docx.oxml.ns import qn
|
||||
@ -136,7 +139,7 @@ def test_extract_images_from_docx(monkeypatch):
|
||||
monkeypatch.setattr(we, "UploadFile", FakeUploadFile)
|
||||
|
||||
# Patch external image fetcher
|
||||
def fake_get(url: str):
|
||||
def fake_get(url: str, **kwargs):
|
||||
assert url == "https://example.com/image.png"
|
||||
return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes)
|
||||
|
||||
@ -203,10 +206,8 @@ def test_extract_images_from_docx_uses_internal_files_url():
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
if original_files_url is not None:
|
||||
dify_config.FILES_URL = original_files_url
|
||||
if original_internal_files_url is not None:
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
dify_config.FILES_URL = original_files_url
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
|
||||
|
||||
def test_extract_hyperlinks(monkeypatch):
|
||||
@ -314,3 +315,313 @@ def test_extract_legacy_hyperlinks(monkeypatch):
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def test_init_rejects_invalid_url_status(monkeypatch):
|
||||
class FakeResponse:
|
||||
status_code = 404
|
||||
content = b""
|
||||
closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
fake_response = FakeResponse()
|
||||
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=lambda url, **kwargs: fake_response))
|
||||
|
||||
with pytest.raises(ValueError, match="returned status code 404"):
|
||||
WordExtractor("https://example.com/missing.docx", "tenant", "user")
|
||||
|
||||
assert fake_response.closed is True
|
||||
|
||||
|
||||
def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
|
||||
target_file = tmp_path / "expanded.docx"
|
||||
target_file.write_bytes(b"docx")
|
||||
|
||||
monkeypatch.setattr(we.os.path, "expanduser", lambda p: str(target_file))
|
||||
monkeypatch.setattr(
|
||||
we.os.path,
|
||||
"isfile",
|
||||
lambda p: p == str(target_file),
|
||||
)
|
||||
|
||||
extractor = WordExtractor("~/expanded.docx", "tenant", "user")
|
||||
assert extractor.file_path == str(target_file)
|
||||
|
||||
monkeypatch.setattr(we.os.path, "isfile", lambda p: False)
|
||||
with pytest.raises(ValueError, match="is not a valid file or url"):
|
||||
WordExtractor("not-a-file", "tenant", "user")
|
||||
|
||||
|
||||
def test_del_closes_temp_file():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor.temp_file = MagicMock()
|
||||
|
||||
WordExtractor.__del__(extractor)
|
||||
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
|
||||
|
||||
def test_extract_images_handles_invalid_external_cases(monkeypatch):
|
||||
class FakeTargetRef:
|
||||
def __contains__(self, item):
|
||||
return item == "image"
|
||||
|
||||
def split(self, sep):
|
||||
return [None]
|
||||
|
||||
rel_invalid_url = SimpleNamespace(is_external=True, target_ref="image-no-url")
|
||||
rel_request_error = SimpleNamespace(is_external=True, target_ref="https://example.com/image-error")
|
||||
rel_unknown_mime = SimpleNamespace(is_external=True, target_ref="https://example.com/image-unknown")
|
||||
rel_internal_none_ext = SimpleNamespace(is_external=False, target_ref=FakeTargetRef(), target_part=object())
|
||||
|
||||
doc = SimpleNamespace(
|
||||
part=SimpleNamespace(
|
||||
rels={
|
||||
"r1": rel_invalid_url,
|
||||
"r2": rel_request_error,
|
||||
"r3": rel_unknown_mime,
|
||||
"r4": rel_internal_none_ext,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def fake_get(url, **kwargs):
|
||||
if "image-error" in url:
|
||||
raise RuntimeError("network")
|
||||
return SimpleNamespace(status_code=200, headers={"Content-Type": "application/unknown"}, content=b"x")
|
||||
|
||||
monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))
|
||||
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=MagicMock()))
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None))
|
||||
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
|
||||
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor.tenant_id = "tenant"
|
||||
extractor.user_id = "user"
|
||||
|
||||
result = extractor._extract_images_from_docx(doc)
|
||||
|
||||
assert result == {}
|
||||
db_stub.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_table_to_markdown_and_parse_helpers(monkeypatch):
|
||||
extractor = object.__new__(WordExtractor)
|
||||
|
||||
table = SimpleNamespace(
|
||||
rows=[
|
||||
SimpleNamespace(cells=[1, 2]),
|
||||
SimpleNamespace(cells=[3, 4]),
|
||||
]
|
||||
)
|
||||
parse_row_mock = MagicMock(side_effect=[["H1", "H2"], ["A", "B"]])
|
||||
monkeypatch.setattr(extractor, "_parse_row", parse_row_mock)
|
||||
|
||||
markdown = extractor._table_to_markdown(table, {})
|
||||
assert markdown == "| H1 | H2 |\n| --- | --- |\n| A | B |"
|
||||
|
||||
class FakeRunElement:
|
||||
def __init__(self, blips):
|
||||
self._blips = blips
|
||||
|
||||
def xpath(self, pattern):
|
||||
if pattern == ".//a:blip":
|
||||
return self._blips
|
||||
return []
|
||||
|
||||
class FakeBlip:
|
||||
def __init__(self, image_id):
|
||||
self.image_id = image_id
|
||||
|
||||
def get(self, key):
|
||||
return self.image_id
|
||||
|
||||
image_part = object()
|
||||
paragraph = SimpleNamespace(
|
||||
runs=[
|
||||
SimpleNamespace(element=FakeRunElement([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")]), text=""),
|
||||
SimpleNamespace(element=FakeRunElement([]), text="plain"),
|
||||
],
|
||||
part=SimpleNamespace(
|
||||
rels={
|
||||
"ext": SimpleNamespace(is_external=True),
|
||||
"int": SimpleNamespace(is_external=False, target_part=image_part),
|
||||
}
|
||||
),
|
||||
)
|
||||
image_map = {"ext": "EXT-IMG", image_part: "INT-IMG"}
|
||||
assert extractor._parse_cell_paragraph(paragraph, image_map) == "EXT-IMGINT-IMGplain"
|
||||
|
||||
cell = SimpleNamespace(paragraphs=[paragraph, paragraph])
|
||||
assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain"
|
||||
|
||||
|
||||
def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch):
|
||||
extractor = object.__new__(WordExtractor)
|
||||
|
||||
ext_image_id = "ext-image"
|
||||
int_embed_id = "int-embed"
|
||||
shape_ext_id = "shape-ext"
|
||||
shape_int_id = "shape-int"
|
||||
|
||||
internal_part = object()
|
||||
shape_internal_part = object()
|
||||
|
||||
class Rels(UserDict):
|
||||
def get(self, key, default=None):
|
||||
if key == "link-bad":
|
||||
raise RuntimeError("cannot resolve relation")
|
||||
return super().get(key, default)
|
||||
|
||||
rels = Rels(
|
||||
{
|
||||
ext_image_id: SimpleNamespace(is_external=True, target_ref="https://img/ext.png"),
|
||||
int_embed_id: SimpleNamespace(is_external=False, target_part=internal_part),
|
||||
shape_ext_id: SimpleNamespace(is_external=True, target_ref="https://img/shape.png"),
|
||||
shape_int_id: SimpleNamespace(is_external=False, target_part=shape_internal_part),
|
||||
"link-ok": SimpleNamespace(is_external=True, target_ref="https://example.com"),
|
||||
}
|
||||
)
|
||||
|
||||
image_map = {
|
||||
ext_image_id: "[EXT]",
|
||||
internal_part: "[INT]",
|
||||
shape_ext_id: "[SHAPE_EXT]",
|
||||
shape_internal_part: "[SHAPE_INT]",
|
||||
}
|
||||
|
||||
class FakeBlip:
|
||||
def __init__(self, embed_id):
|
||||
self.embed_id = embed_id
|
||||
|
||||
def get(self, key):
|
||||
return self.embed_id
|
||||
|
||||
class FakeDrawing:
|
||||
def __init__(self, embed_ids):
|
||||
self.embed_ids = embed_ids
|
||||
|
||||
def findall(self, pattern):
|
||||
return [FakeBlip(embed_id) for embed_id in self.embed_ids]
|
||||
|
||||
class FakeNode:
|
||||
def __init__(self, text=None, attrs=None):
|
||||
self.text = text
|
||||
self._attrs = attrs or {}
|
||||
|
||||
def get(self, key):
|
||||
return self._attrs.get(key)
|
||||
|
||||
class FakeShape:
|
||||
def __init__(self, bin_id=None, img_id=None):
|
||||
self.bin_id = bin_id
|
||||
self.img_id = img_id
|
||||
|
||||
def find(self, pattern):
|
||||
if "binData" in pattern and self.bin_id:
|
||||
return FakeNode(
|
||||
text="shape",
|
||||
attrs={"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id": self.bin_id},
|
||||
)
|
||||
if "imagedata" in pattern and self.img_id:
|
||||
return FakeNode(attrs={"id": self.img_id})
|
||||
return None
|
||||
|
||||
class FakeChild:
|
||||
def __init__(
|
||||
self,
|
||||
tag,
|
||||
text="",
|
||||
fld_chars=None,
|
||||
instr_texts=None,
|
||||
drawings=None,
|
||||
shapes=None,
|
||||
attrs=None,
|
||||
hyperlink_runs=None,
|
||||
):
|
||||
self.tag = tag
|
||||
self.text = text
|
||||
self._fld_chars = fld_chars or []
|
||||
self._instr_texts = instr_texts or []
|
||||
self._drawings = drawings or []
|
||||
self._shapes = shapes or []
|
||||
self._attrs = attrs or {}
|
||||
self._hyperlink_runs = hyperlink_runs or []
|
||||
|
||||
def findall(self, pattern):
|
||||
if pattern == qn("w:fldChar"):
|
||||
return self._fld_chars
|
||||
if pattern == qn("w:instrText"):
|
||||
return self._instr_texts
|
||||
if pattern == qn("w:r"):
|
||||
return self._hyperlink_runs
|
||||
if pattern.endswith("}drawing"):
|
||||
return self._drawings
|
||||
if pattern.endswith("}pict"):
|
||||
return self._shapes
|
||||
return []
|
||||
|
||||
def get(self, key):
|
||||
return self._attrs.get(key)
|
||||
|
||||
class FakeRun:
|
||||
def __init__(self, element, paragraph):
|
||||
self.element = element
|
||||
self.text = getattr(element, "text", "")
|
||||
|
||||
paragraph_main = SimpleNamespace(
|
||||
_element=[
|
||||
FakeChild(
|
||||
qn("w:r"),
|
||||
text="run-text",
|
||||
drawings=[FakeDrawing([ext_image_id, int_embed_id])],
|
||||
shapes=[FakeShape(bin_id=shape_ext_id, img_id=shape_int_id)],
|
||||
),
|
||||
FakeChild(
|
||||
qn("w:r"),
|
||||
text="",
|
||||
drawings=[],
|
||||
shapes=[FakeShape(bin_id=shape_ext_id)],
|
||||
),
|
||||
FakeChild(
|
||||
qn("w:hyperlink"),
|
||||
attrs={qn("r:id"): "link-ok"},
|
||||
hyperlink_runs=[FakeChild(qn("w:r"), text="LinkText")],
|
||||
),
|
||||
FakeChild(
|
||||
qn("w:hyperlink"),
|
||||
attrs={qn("r:id"): "link-bad"},
|
||||
hyperlink_runs=[FakeChild(qn("w:r"), text="BrokenLink")],
|
||||
),
|
||||
]
|
||||
)
|
||||
paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")])
|
||||
|
||||
fake_doc = SimpleNamespace(
|
||||
part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}),
|
||||
paragraphs=[paragraph_main, paragraph_empty],
|
||||
tables=[SimpleNamespace(rows=[])],
|
||||
element=SimpleNamespace(
|
||||
body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")]
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc)
|
||||
monkeypatch.setattr(we, "Run", FakeRun)
|
||||
monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map)
|
||||
monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN")
|
||||
logger_exception = MagicMock()
|
||||
monkeypatch.setattr(we.logger, "exception", logger_exception)
|
||||
|
||||
content = extractor.parse_docx("dummy.docx")
|
||||
|
||||
assert "[EXT]" in content
|
||||
assert "[INT]" in content
|
||||
assert "[SHAPE_EXT]" in content
|
||||
assert "[LinkText](https://example.com)" in content
|
||||
assert "BrokenLink" in content
|
||||
assert "TABLE-MARKDOWN" in content
|
||||
logger_exception.assert_called_once()
|
||||
|
||||
@ -0,0 +1,300 @@
|
||||
"""Unit tests for unstructured extractors and their local/API partitioning paths."""
|
||||
|
||||
import base64
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.extractor.unstructured.unstructured_epub_extractor as epub_module
|
||||
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
|
||||
|
||||
|
||||
def _register_module(monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object) -> types.ModuleType:
|
||||
module = types.ModuleType(name)
|
||||
for k, v in attrs.items():
|
||||
setattr(module, k, v)
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
return module
|
||||
|
||||
|
||||
def _register_unstructured_packages(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_register_module(monkeypatch, "unstructured", __path__=[])
|
||||
_register_module(monkeypatch, "unstructured.partition", __path__=[])
|
||||
_register_module(monkeypatch, "unstructured.chunking", __path__=[])
|
||||
_register_module(monkeypatch, "unstructured.file_utils", __path__=[])
|
||||
|
||||
|
||||
def _install_chunk_by_title(monkeypatch: pytest.MonkeyPatch, chunks: list[SimpleNamespace]) -> None:
|
||||
_register_unstructured_packages(monkeypatch)
|
||||
|
||||
def chunk_by_title(
|
||||
elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int
|
||||
) -> list[SimpleNamespace]:
|
||||
return chunks
|
||||
|
||||
_register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title)
|
||||
|
||||
|
||||
class TestUnstructuredMarkdownMsgXml:
|
||||
def test_markdown_extractor_without_api(self, monkeypatch):
|
||||
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" chunk-1 "), SimpleNamespace(text=" chunk-2 ")])
|
||||
_register_module(
|
||||
monkeypatch, "unstructured.partition.md", partition_md=lambda filename: [SimpleNamespace(text="x")]
|
||||
)
|
||||
|
||||
docs = UnstructuredMarkdownExtractor("/tmp/file.md").extract()
|
||||
|
||||
assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"]
|
||||
|
||||
def test_markdown_extractor_with_api(self, monkeypatch):
|
||||
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" via-api ")])
|
||||
calls = {}
|
||||
|
||||
def partition_via_api(filename, api_url, api_key):
|
||||
calls.update({"filename": filename, "api_url": api_url, "api_key": api_key})
|
||||
return [SimpleNamespace(text="ignored")]
|
||||
|
||||
_register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api)
|
||||
|
||||
docs = UnstructuredMarkdownExtractor("/tmp/file.md", api_url="https://u", api_key="k").extract()
|
||||
|
||||
assert docs[0].page_content == "via-api"
|
||||
assert calls == {"filename": "/tmp/file.md", "api_url": "https://u", "api_key": "k"}
|
||||
|
||||
def test_msg_extractor_local(self, monkeypatch):
|
||||
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")])
|
||||
_register_module(
|
||||
monkeypatch, "unstructured.partition.msg", partition_msg=lambda filename: [SimpleNamespace(text="x")]
|
||||
)
|
||||
|
||||
assert UnstructuredMsgExtractor("/tmp/file.msg").extract()[0].page_content == "msg-doc"
|
||||
|
||||
def test_msg_extractor_with_api(self, monkeypatch):
|
||||
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")])
|
||||
calls = {}
|
||||
|
||||
def partition_via_api(filename, api_url, api_key):
|
||||
calls.update({"filename": filename, "api_url": api_url, "api_key": api_key})
|
||||
return [SimpleNamespace(text="x")]
|
||||
|
||||
_register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api)
|
||||
|
||||
assert (
|
||||
UnstructuredMsgExtractor("/tmp/file.msg", api_url="https://u", api_key="k").extract()[0].page_content
|
||||
== "msg-doc"
|
||||
)
|
||||
assert calls["filename"] == "/tmp/file.msg"
|
||||
|
||||
def test_xml_extractor_local_and_api(self, monkeypatch):
|
||||
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="xml-doc")])
|
||||
|
||||
xml_calls = {}
|
||||
|
||||
def partition_xml(filename, xml_keep_tags):
|
||||
xml_calls.update({"filename": filename, "xml_keep_tags": xml_keep_tags})
|
||||
return [SimpleNamespace(text="x")]
|
||||
|
||||
_register_module(monkeypatch, "unstructured.partition.xml", partition_xml=partition_xml)
|
||||
|
||||
assert UnstructuredXmlExtractor("/tmp/file.xml").extract()[0].page_content == "xml-doc"
|
||||
assert xml_calls == {"filename": "/tmp/file.xml", "xml_keep_tags": True}
|
||||
|
||||
api_calls = {}
|
||||
|
||||
def partition_via_api(filename, api_url, api_key):
|
||||
api_calls.update({"filename": filename, "api_url": api_url, "api_key": api_key})
|
||||
return [SimpleNamespace(text="x")]
|
||||
|
||||
_register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api)
|
||||
|
||||
assert (
|
||||
UnstructuredXmlExtractor("/tmp/file.xml", api_url="https://u", api_key="k").extract()[0].page_content
|
||||
== "xml-doc"
|
||||
)
|
||||
assert api_calls["filename"] == "/tmp/file.xml"
|
||||
|
||||
|
||||
class TestUnstructuredEmailAndEpub:
|
||||
def test_email_extractor_local_decodes_html_and_suppresses_decode_errors(self, monkeypatch):
|
||||
_register_unstructured_packages(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def chunk_by_title(
|
||||
elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int
|
||||
) -> list[SimpleNamespace]:
|
||||
captured["elements"] = list(elements)
|
||||
return [SimpleNamespace(text=" chunked-email ")]
|
||||
|
||||
_register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title)
|
||||
|
||||
html = "<p>Hello Email</p>"
|
||||
encoded_html = base64.b64encode(html.encode("utf-8")).decode("utf-8")
|
||||
bad_base64 = "not-base64"
|
||||
|
||||
elements = [SimpleNamespace(text=encoded_html), SimpleNamespace(text=bad_base64)]
|
||||
_register_module(monkeypatch, "unstructured.partition.email", partition_email=lambda filename: elements)
|
||||
|
||||
docs = UnstructuredEmailExtractor("/tmp/file.eml").extract()
|
||||
|
||||
assert docs[0].page_content == "chunked-email"
|
||||
chunk_elements = captured["elements"]
|
||||
assert "Hello Email" in chunk_elements[0].text
|
||||
assert chunk_elements[1].text == bad_base64
|
||||
|
||||
def test_email_extractor_with_api(self, monkeypatch):
|
||||
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="api-email")])
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.partition.api",
|
||||
partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="abc")],
|
||||
)
|
||||
|
||||
docs = UnstructuredEmailExtractor("/tmp/file.eml", api_url="https://u", api_key="k").extract()
|
||||
|
||||
assert docs[0].page_content == "api-email"
|
||||
|
||||
def test_epub_extractor_local_and_api(self, monkeypatch):
|
||||
_install_chunk_by_title(monkeypatch, [SimpleNamespace(text="epub-doc")])
|
||||
|
||||
calls = {"download": 0, "partition": 0}
|
||||
|
||||
def fake_download_pandoc():
|
||||
calls["download"] += 1
|
||||
|
||||
def partition_epub(filename, xml_keep_tags):
|
||||
calls["partition"] += 1
|
||||
assert xml_keep_tags is True
|
||||
return [SimpleNamespace(text="x")]
|
||||
|
||||
monkeypatch.setattr(epub_module.pypandoc, "download_pandoc", fake_download_pandoc)
|
||||
_register_module(monkeypatch, "unstructured.partition.epub", partition_epub=partition_epub)
|
||||
|
||||
docs = UnstructuredEpubExtractor("/tmp/file.epub").extract()
|
||||
|
||||
assert docs[0].page_content == "epub-doc"
|
||||
assert calls == {"download": 1, "partition": 1}
|
||||
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.partition.api",
|
||||
partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="x")],
|
||||
)
|
||||
|
||||
docs = UnstructuredEpubExtractor("/tmp/file.epub", api_url="https://u", api_key="k").extract()
|
||||
assert docs[0].page_content == "epub-doc"
|
||||
|
||||
|
||||
class TestUnstructuredPPTAndPPTX:
|
||||
def test_ppt_extractor_requires_api_url(self):
|
||||
with pytest.raises(NotImplementedError, match="Unstructured API Url is not configured"):
|
||||
UnstructuredPPTExtractor("/tmp/file.ppt").extract()
|
||||
|
||||
def test_ppt_extractor_groups_text_by_page(self, monkeypatch):
|
||||
_register_unstructured_packages(monkeypatch)
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.partition.api",
|
||||
partition_via_api=lambda filename, api_url, api_key: [
|
||||
SimpleNamespace(text="A", metadata=SimpleNamespace(page_number=1)),
|
||||
SimpleNamespace(text="B", metadata=SimpleNamespace(page_number=1)),
|
||||
SimpleNamespace(text="skip", metadata=SimpleNamespace(page_number=None)),
|
||||
SimpleNamespace(text="C", metadata=SimpleNamespace(page_number=2)),
|
||||
],
|
||||
)
|
||||
|
||||
docs = UnstructuredPPTExtractor("/tmp/file.ppt", api_url="https://u", api_key="k").extract()
|
||||
|
||||
assert [doc.page_content for doc in docs] == ["A\nB", "C"]
|
||||
|
||||
def test_pptx_extractor_local_and_api(self, monkeypatch):
|
||||
_register_unstructured_packages(monkeypatch)
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.partition.pptx",
|
||||
partition_pptx=lambda filename: [
|
||||
SimpleNamespace(text="P1", metadata=SimpleNamespace(page_number=1)),
|
||||
SimpleNamespace(text="P2", metadata=SimpleNamespace(page_number=2)),
|
||||
SimpleNamespace(text="Skip", metadata=SimpleNamespace(page_number=None)),
|
||||
],
|
||||
)
|
||||
|
||||
docs = UnstructuredPPTXExtractor("/tmp/file.pptx").extract()
|
||||
assert [doc.page_content for doc in docs] == ["P1", "P2"]
|
||||
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.partition.api",
|
||||
partition_via_api=lambda filename, api_url, api_key: [
|
||||
SimpleNamespace(text="X", metadata=SimpleNamespace(page_number=1)),
|
||||
SimpleNamespace(text="Y", metadata=SimpleNamespace(page_number=1)),
|
||||
],
|
||||
)
|
||||
|
||||
docs = UnstructuredPPTXExtractor("/tmp/file.pptx", api_url="https://u", api_key="k").extract()
|
||||
assert [doc.page_content for doc in docs] == ["X\nY"]
|
||||
|
||||
|
||||
class TestUnstructuredWord:
|
||||
def _install_doc_modules(self, monkeypatch, version: str, filetype_value):
|
||||
_register_unstructured_packages(monkeypatch)
|
||||
|
||||
class FileType:
|
||||
DOC = "doc"
|
||||
|
||||
_register_module(monkeypatch, "unstructured.__version__", __version__=version)
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.file_utils.filetype",
|
||||
FileType=FileType,
|
||||
detect_filetype=lambda filename: filetype_value,
|
||||
)
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.partition.api",
|
||||
partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="api-doc")],
|
||||
)
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.partition.docx",
|
||||
partition_docx=lambda filename: [SimpleNamespace(text="docx-doc")],
|
||||
)
|
||||
_register_module(
|
||||
monkeypatch,
|
||||
"unstructured.chunking.title",
|
||||
chunk_by_title=lambda elements, max_characters, combine_text_under_n_chars: [
|
||||
SimpleNamespace(text="chunk-1"),
|
||||
SimpleNamespace(text="chunk-2"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_word_extractor_rejects_doc_on_old_unstructured_version(self, monkeypatch):
|
||||
self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="doc")
|
||||
|
||||
with pytest.raises(ValueError, match="Partitioning .doc files is only supported"):
|
||||
UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract()
|
||||
|
||||
def test_word_extractor_doc_and_docx_paths(self, monkeypatch):
|
||||
self._install_doc_modules(monkeypatch, version="0.4.11", filetype_value="doc")
|
||||
|
||||
docs = UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract()
|
||||
assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"]
|
||||
|
||||
self._install_doc_modules(monkeypatch, version="0.5.0", filetype_value="not-doc")
|
||||
docs = UnstructuredWordExtractor("/tmp/file.docx", "https://u", "k").extract()
|
||||
assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"]
|
||||
|
||||
def test_word_extractor_magic_import_error_fallback_to_extension(self, monkeypatch):
|
||||
self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="not-used")
|
||||
monkeypatch.setitem(sys.modules, "magic", None)
|
||||
|
||||
with pytest.raises(ValueError, match="Partitioning .doc files is only supported"):
|
||||
UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract()
|
||||
@ -0,0 +1,434 @@
|
||||
"""Unit tests for WaterCrawl client, provider, and extractor behavior."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.rag.extractor.watercrawl.client as client_module
|
||||
from core.rag.extractor.watercrawl.client import BaseAPIClient, WaterCrawlAPIClient
|
||||
from core.rag.extractor.watercrawl.exceptions import (
|
||||
WaterCrawlAuthenticationError,
|
||||
WaterCrawlBadRequestError,
|
||||
WaterCrawlPermissionError,
|
||||
)
|
||||
from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor
|
||||
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
|
||||
|
||||
|
||||
def _response(
|
||||
status_code: int,
|
||||
json_data: dict[str, Any] | None = None,
|
||||
content_type: str = "application/json",
|
||||
content: bytes = b"",
|
||||
text: str = "",
|
||||
) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
response.headers = {"Content-Type": content_type}
|
||||
response.content = content
|
||||
response.text = text
|
||||
response.json.return_value = json_data if json_data is not None else {}
|
||||
response.raise_for_status.return_value = None
|
||||
response.close.return_value = None
|
||||
return response
|
||||
|
||||
|
||||
class TestWaterCrawlExceptions:
    """Exercise the error classes raised by the WaterCrawl client."""

    def test_bad_request_error_properties_and_string(self):
        resp = _response(400, {"message": "bad request", "errors": {"url": ["invalid"]}})

        error = WaterCrawlBadRequestError(resp)
        parsed = json.loads(error.flat_errors)

        assert error.status_code == 400
        assert error.message == "bad request"
        assert "url" in parsed
        assert any("invalid" in str(entry) for entry in parsed["url"])
        # The class name should be visible in the string form for log readability.
        assert "WaterCrawlBadRequestError" in str(error)

    def test_permission_and_authentication_error_strings(self):
        resp = _response(403, {"message": "quota exceeded", "errors": {}})

        assert "exceeding your WaterCrawl API limits" in str(WaterCrawlPermissionError(resp))
        assert "API key is invalid or expired" in str(WaterCrawlAuthenticationError(resp))
|
||||
|
||||
|
||||
class TestBaseAPIClient:
    """Behavior of the low-level HTTP client shared by the WaterCrawl APIs."""

    def test_init_session_builds_expected_headers(self, monkeypatch):
        seen_kwargs = {}

        def record_client(**kwargs):
            seen_kwargs.update(kwargs)
            return "session"

        monkeypatch.setattr(client_module.httpx, "Client", record_client)

        client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev")

        assert client.session == "session"
        assert seen_kwargs["headers"]["X-API-Key"] == "k"
        assert seen_kwargs["headers"]["User-Agent"] == "WaterCrawl-Plugin"

    def test_request_stream_and_non_stream_paths(self, monkeypatch):
        class RecordingSession:
            def __init__(self):
                self.request_calls = []
                self.build_calls = []
                self.send_calls = []

            def request(self, method, url, params=None, json=None, **kwargs):
                self.request_calls.append((method, url, params, json, kwargs))
                return "non-stream-response"

            def build_request(self, method, url, params=None, json=None):
                built = (method, url, params, json)
                self.build_calls.append(built)
                return built

            def send(self, request, stream=False, **kwargs):
                self.send_calls.append((request, stream, kwargs))
                return "stream-response"

        recording = RecordingSession()
        monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: recording)

        client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev")

        # Non-streaming requests go straight through Session.request with an absolute URL.
        assert client._request("GET", "/v1/items", query_params={"a": 1}) == "non-stream-response"
        assert recording.request_calls[0][1] == "https://watercrawl.dev/v1/items"

        # Streaming requests are built first, then sent with stream=True.
        assert client._request("GET", "/v1/items", stream=True) == "stream-response"
        assert recording.build_calls
        assert recording.send_calls[0][1] is True

    def test_http_method_helpers_delegate_to_request(self, monkeypatch):
        monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: MagicMock())
        client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev")

        recorded = []

        def record_request(method, endpoint, query_params=None, data=None, **kwargs):
            recorded.append((method, endpoint, query_params, data))
            return "ok"

        monkeypatch.setattr(client, "_request", record_request)

        assert client._get("/a") == "ok"
        assert client._post("/b", data={"x": 1}) == "ok"
        assert client._put("/c", data={"x": 2}) == "ok"
        assert client._delete("/d") == "ok"
        assert client._patch("/e", data={"x": 3}) == "ok"
        assert [entry[0] for entry in recorded] == ["GET", "POST", "PUT", "DELETE", "PATCH"]
|
||||
|
||||
|
||||
class TestWaterCrawlAPIClient:
    """Response handling and endpoint wrappers of the WaterCrawl API client."""

    def test_process_eventstream_and_download(self, monkeypatch):
        client = WaterCrawlAPIClient(api_key="k")

        response = MagicMock()
        response.iter_lines.return_value = [
            b"event: keep-alive",
            b'data: {"type":"result","data":{"result":"http://x"}}',
            b'data: {"type":"log","data":{"msg":"ok"}}',
        ]

        monkeypatch.setattr(client, "download_result", lambda data: {"result": {"markdown": "body"}, "url": "u"})

        events = list(client.process_eventstream(response, download=True))

        # Result events are resolved through download_result; other events pass through.
        assert events[0]["data"]["result"]["markdown"] == "body"
        assert events[1]["type"] == "log"
        response.close.assert_called_once()

    @pytest.mark.parametrize(
        ("status", "expected_exception"),
        [
            (401, WaterCrawlAuthenticationError),
            (403, WaterCrawlPermissionError),
            (422, WaterCrawlBadRequestError),
        ],
    )
    def test_process_response_error_statuses(self, status: int, expected_exception: type[Exception]):
        client = WaterCrawlAPIClient(api_key="k")

        with pytest.raises(expected_exception):
            client.process_response(_response(status, {"message": "bad", "errors": {"url": ["x"]}}))

    def test_process_response_204_returns_none(self):
        client = WaterCrawlAPIClient(api_key="k")
        assert client.process_response(_response(204, None)) is None

    def test_process_response_json_payloads(self):
        client = WaterCrawlAPIClient(api_key="k")
        assert client.process_response(_response(200, {"ok": True})) == {"ok": True}
        assert client.process_response(_response(200, None)) == {}

    def test_process_response_octet_stream_returns_bytes(self):
        client = WaterCrawlAPIClient(api_key="k")
        raw = client.process_response(_response(200, content_type="application/octet-stream", content=b"bin"))
        assert raw == b"bin"

    def test_process_response_event_stream_returns_generator(self, monkeypatch):
        client = WaterCrawlAPIClient(api_key="k")
        stream = (item for item in [{"type": "result", "data": {}}])
        monkeypatch.setattr(client, "process_eventstream", lambda response, download=False: stream)
        assert client.process_response(_response(200, content_type="text/event-stream")) is stream

    def test_process_response_unknown_content_type_raises(self):
        client = WaterCrawlAPIClient(api_key="k")
        with pytest.raises(Exception, match="Unknown response type"):
            client.process_response(_response(200, content_type="text/plain", text="x"))

    def test_process_response_uses_raise_for_status(self):
        client = WaterCrawlAPIClient(api_key="k")
        response = _response(500, {"message": "server"})
        response.raise_for_status.side_effect = RuntimeError("http error")

        with pytest.raises(RuntimeError, match="http error"):
            client.process_response(response)

    def test_endpoint_wrappers(self, monkeypatch):
        client = WaterCrawlAPIClient(api_key="k")

        monkeypatch.setattr(client, "process_response", lambda resp: "processed")
        monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "get-resp")
        monkeypatch.setattr(client, "_post", lambda *args, **kwargs: "post-resp")
        monkeypatch.setattr(client, "_delete", lambda *args, **kwargs: "delete-resp")

        # Every wrapper funnels its raw response through process_response.
        assert client.get_crawl_requests_list() == "processed"
        assert client.get_crawl_request("id") == "processed"
        assert client.create_crawl_request(url="https://x") == "processed"
        assert client.stop_crawl_request("id") == "processed"
        assert client.download_crawl_request("id") == "processed"
        assert client.get_crawl_request_results("id") == "processed"

    def test_monitor_crawl_request_generator_and_validation(self, monkeypatch):
        client = WaterCrawlAPIClient(api_key="k")

        monkeypatch.setattr(client, "process_response", lambda _: (x for x in [{"type": "result", "data": 1}]))
        monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "stream-resp")

        assert list(client.monitor_crawl_request("job-1", prefetched=True)) == [{"type": "result", "data": 1}]

        # A non-generator from process_response is rejected as a programming error.
        monkeypatch.setattr(client, "process_response", lambda _: [{"type": "result"}])
        with pytest.raises(ValueError, match="Generator expected"):
            list(client.monitor_crawl_request("job-1"))

    def test_scrape_url_sync_and_async(self, monkeypatch):
        client = WaterCrawlAPIClient(api_key="k")
        monkeypatch.setattr(client, "create_crawl_request", lambda **kwargs: {"uuid": "job-1"})

        # Async mode returns the raw crawl-request payload immediately.
        assert client.scrape_url("https://example.com", sync=False) == {"uuid": "job-1"}

        monkeypatch.setattr(
            client,
            "monitor_crawl_request",
            lambda item_id, prefetched: iter(
                [{"type": "log", "data": {}}, {"type": "result", "data": {"url": "https://example.com"}}]
            ),
        )
        # Sync mode consumes events until the result event and unwraps its data.
        assert client.scrape_url("https://example.com", sync=True) == {"url": "https://example.com"}

    def test_download_result_fetches_json_and_closes(self, monkeypatch):
        client = WaterCrawlAPIClient(api_key="k")

        response = _response(200, {"markdown": "body"})
        monkeypatch.setattr(client_module.httpx, "get", lambda *args, **kwargs: response)

        result = client.download_result({"result": "https://example.com/result.json"})

        assert result["result"] == {"markdown": "body"}
        response.close.assert_called_once()
|
||||
|
||||
|
||||
class TestWaterCrawlProvider:
    """Provider-level orchestration built on top of the WaterCrawl client."""

    def test_crawl_url_builds_options_and_min_wait_time(self, monkeypatch):
        provider = WaterCrawlProvider(api_key="k")
        seen = {}

        def capture_create(**kwargs):
            seen.update(kwargs)
            return {"uuid": "job-1"}

        monkeypatch.setattr(provider.client, "create_crawl_request", capture_create)

        result = provider.crawl_url(
            "https://example.com",
            {
                "crawl_sub_pages": True,
                "limit": 5,
                "max_depth": 2,
                "includes": "a,b",
                "excludes": "x,y",
                "exclude_tags": "nav,footer",
                "include_tags": "main",
                "wait_time": 100,
                "only_main_content": False,
            },
        )

        assert result == {"status": "active", "job_id": "job-1"}
        assert seen["url"] == "https://example.com"
        assert seen["spider_options"] == {
            "max_depth": 2,
            "page_limit": 5,
            "allowed_domains": [],
            "exclude_paths": ["x", "y"],
            "include_paths": ["a", "b"],
        }
        assert seen["page_options"]["exclude_tags"] == ["nav", "footer"]
        assert seen["page_options"]["include_tags"] == ["main"]
        assert seen["page_options"]["only_main_content"] is False
        # A wait_time of 100 is raised to the provider's minimum of 1000.
        assert seen["page_options"]["wait_time"] == 1000

    def test_get_crawl_status_active_and_completed(self, monkeypatch):
        provider = WaterCrawlProvider(api_key="k")

        def running_request(job_id):
            return {
                "status": "running",
                "uuid": job_id,
                "options": {"spider_options": {"page_limit": 3}},
                "number_of_documents": 1,
                "duration": "00:00:01.500000",
            }

        monkeypatch.setattr(provider.client, "get_crawl_request", running_request)

        active = provider.get_crawl_status("job-1")
        assert active["status"] == "active"
        assert active["data"] == []
        assert active["time_consuming"] == pytest.approx(1.5)

        def completed_request(job_id):
            return {
                "status": "completed",
                "uuid": job_id,
                "options": {"spider_options": {"page_limit": 2}},
                "number_of_documents": 2,
                "duration": "00:00:02.000000",
            }

        monkeypatch.setattr(provider.client, "get_crawl_request", completed_request)
        monkeypatch.setattr(provider, "_get_results", lambda crawl_request_id, query_params=None: iter([{"url": "u"}]))

        completed = provider.get_crawl_status("job-2")
        assert completed["status"] == "completed"
        assert completed["data"] == [{"url": "u"}]

    def test_get_crawl_url_data_and_scrape(self, monkeypatch):
        provider = WaterCrawlProvider(api_key="k")

        # An empty job id falls back to a live scrape of the URL.
        monkeypatch.setattr(provider, "scrape_url", lambda url: {"source_url": url})
        assert provider.get_crawl_url_data("", "https://example.com") == {"source_url": "https://example.com"}

        # With a job id, the matching stored result is returned.
        monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([{"source_url": "u1"}]))
        assert provider.get_crawl_url_data("job", "u1") == {"source_url": "u1"}

        # No matching stored result yields None.
        monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([]))
        assert provider.get_crawl_url_data("job", "u1") is None

    def test_structure_data_validation_and_get_results_pagination(self, monkeypatch):
        provider = WaterCrawlProvider(api_key="k")

        with pytest.raises(ValueError, match="Invalid result object"):
            provider._structure_data({"result": "not-a-dict"})

        structured = provider._structure_data(
            {
                "url": "https://example.com",
                "result": {
                    "metadata": {"title": "Title", "description": "Desc"},
                    "markdown": "Body",
                },
            }
        )
        assert structured["title"] == "Title"
        assert structured["markdown"] == "Body"

        pages = [
            {
                "results": [
                    {
                        "url": "https://a",
                        "result": {"metadata": {"title": "A", "description": "DA"}, "markdown": "MA"},
                    }
                ],
                "next": "next-page",
            },
            {"results": [], "next": None},
        ]

        monkeypatch.setattr(
            provider.client,
            "get_crawl_request_results",
            lambda crawl_request_id, page, page_size, query_params: pages.pop(0),
        )

        results = list(provider._get_results("job-1"))
        assert len(results) == 1
        assert results[0]["source_url"] == "https://a"

    def test_scrape_url_uses_client_and_structure(self, monkeypatch):
        provider = WaterCrawlProvider(api_key="k")
        monkeypatch.setattr(
            provider.client, "scrape_url", lambda **kwargs: {"result": {"metadata": {}, "markdown": "m"}, "url": "u"}
        )

        assert provider.scrape_url("u")["source_url"] == "u"
|
||||
|
||||
|
||||
class TestWaterCrawlWebExtractor:
    """Extractor dispatch between crawl and scrape modes."""

    def test_extract_crawl_and_scrape_modes(self, monkeypatch):
        monkeypatch.setattr(
            "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data",
            lambda job_id, provider, url, tenant_id: {
                "markdown": "crawl",
                "source_url": url,
                "description": "d",
                "title": "t",
            },
        )
        monkeypatch.setattr(
            "core.rag.extractor.watercrawl.extractor.WebsiteService.get_scrape_url_data",
            lambda provider, url, tenant_id, only_main_content: {
                "markdown": "scrape",
                "source_url": url,
                "description": "d",
                "title": "t",
            },
        )

        crawled = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl").extract()
        scraped = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape").extract()

        assert crawled[0].page_content == "crawl"
        assert scraped[0].page_content == "scrape"

    def test_extract_crawl_returns_empty_when_service_returns_none(self, monkeypatch):
        monkeypatch.setattr(
            "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data",
            lambda job_id, provider, url, tenant_id: None,
        )

        extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl")

        assert extractor.extract() == []

    def test_extract_unknown_mode_returns_empty(self):
        extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="other")

        assert extractor.extract() == []
|
||||
33
api/tests/unit_tests/core/rag/indexing/processor/conftest.py
Normal file
33
api/tests/unit_tests/core/rag/indexing/processor/conftest.py
Normal file
@ -0,0 +1,33 @@
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeFlaskApp:
|
||||
def app_context(self) -> AbstractContextManager[None]:
|
||||
return nullcontext()
|
||||
|
||||
|
||||
class _FakeExecutor:
|
||||
def __init__(self, future: Any) -> None:
|
||||
self._future = future
|
||||
|
||||
def __enter__(self) -> "_FakeExecutor":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool:
|
||||
return False
|
||||
|
||||
def submit(self, func: object, preview: object) -> Any:
|
||||
return self._future
|
||||
|
||||
|
||||
@pytest.fixture
def fake_flask_app() -> _FakeFlaskApp:
    """Provide a Flask-app double whose app_context() does nothing."""
    return _FakeFlaskApp()
|
||||
|
||||
|
||||
@pytest.fixture
def fake_executor_cls() -> type[_FakeExecutor]:
    """Provide the fake-executor class so tests can inject their own future."""
    return _FakeExecutor
|
||||
@ -0,0 +1,629 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
|
||||
class TestParagraphIndexProcessor:
|
||||
@pytest.fixture
|
||||
def processor(self) -> ParagraphIndexProcessor:
|
||||
return ParagraphIndexProcessor()
|
||||
|
||||
@pytest.fixture
|
||||
def dataset(self) -> Mock:
|
||||
dataset = Mock()
|
||||
dataset.id = "dataset-1"
|
||||
dataset.tenant_id = "tenant-1"
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.is_multimodal = True
|
||||
return dataset
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_document(self) -> Mock:
|
||||
document = Mock()
|
||||
document.id = "doc-1"
|
||||
document.created_by = "user-1"
|
||||
return document
|
||||
|
||||
@pytest.fixture
|
||||
def process_rule(self) -> dict:
|
||||
return {
|
||||
"mode": "custom",
|
||||
"rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}},
|
||||
}
|
||||
|
||||
def _rules(self) -> SimpleNamespace:
|
||||
segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n")
|
||||
return SimpleNamespace(segmentation=segmentation)
|
||||
|
||||
def _llm_result(self, content: str = "summary") -> LLMResult:
|
||||
return LLMResult(
|
||||
model="llm-model",
|
||||
message=AssistantPromptMessage(content=content),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
def test_extract_forwards_automatic_flag(self, processor: ParagraphIndexProcessor) -> None:
|
||||
extract_setting = Mock()
|
||||
expected_docs = [Document(page_content="chunk", metadata={})]
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.ExtractProcessor.extract"
|
||||
) as mock_extract:
|
||||
mock_extract.return_value = expected_docs
|
||||
docs = processor.extract(extract_setting, process_rule_mode="hierarchical")
|
||||
|
||||
assert docs == expected_docs
|
||||
mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True)
|
||||
|
||||
def test_transform_validates_process_rule(self, processor: ParagraphIndexProcessor) -> None:
|
||||
with pytest.raises(ValueError, match="No process rule found"):
|
||||
processor.transform([Document(page_content="text", metadata={})], process_rule=None)
|
||||
|
||||
with pytest.raises(ValueError, match="No rules found in process rule"):
|
||||
processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})
|
||||
|
||||
def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None:
|
||||
rules_without_segmentation = SimpleNamespace(segmentation=None)
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate",
|
||||
return_value=rules_without_segmentation,
|
||||
):
|
||||
with pytest.raises(ValueError, match="No segmentation found in rules"):
|
||||
processor.transform(
|
||||
[Document(page_content="text", metadata={})],
|
||||
process_rule={"mode": "custom", "rules": {"enabled": True}},
|
||||
)
|
||||
|
||||
def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None:
|
||||
source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
|
||||
splitter = Mock()
|
||||
splitter.split_documents.return_value = [
|
||||
Document(page_content=".first", metadata={}),
|
||||
Document(page_content=" ", metadata={}),
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate",
|
||||
return_value=self._rules(),
|
||||
),
|
||||
patch.object(processor, "_get_splitter", return_value=splitter),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean",
|
||||
return_value=".first",
|
||||
),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
|
||||
return_value="hash",
|
||||
),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols",
|
||||
side_effect=lambda text: text.lstrip("."),
|
||||
),
|
||||
patch.object(
|
||||
processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
|
||||
),
|
||||
):
|
||||
documents = processor.transform([source_document], process_rule=process_rule)
|
||||
|
||||
assert len(documents) == 1
|
||||
assert documents[0].page_content == "first"
|
||||
assert documents[0].attachments is not None
|
||||
assert documents[0].metadata["doc_hash"] == "hash"
|
||||
|
||||
def test_transform_automatic_mode_uses_default_rules(self, processor: ParagraphIndexProcessor) -> None:
|
||||
splitter = Mock()
|
||||
splitter.split_documents.return_value = [Document(page_content="text", metadata={})]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate",
|
||||
return_value=self._rules(),
|
||||
) as mock_validate,
|
||||
patch.object(processor, "_get_splitter", return_value=splitter),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean",
|
||||
side_effect=lambda text, _: text,
|
||||
),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
|
||||
return_value="hash",
|
||||
),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols",
|
||||
side_effect=lambda text: text,
|
||||
),
|
||||
patch.object(processor, "_get_content_files", return_value=[]),
|
||||
):
|
||||
processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "automatic"})
|
||||
|
||||
assert mock_validate.call_count == 1
|
||||
|
||||
def test_load_creates_vector_and_multimodal_when_high_quality(
|
||||
self, processor: ParagraphIndexProcessor, dataset: Mock
|
||||
) -> None:
|
||||
docs = [Document(page_content="chunk", metadata={})]
|
||||
multimodal_docs = [AttachmentDocument(page_content="image", metadata={})]
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls,
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls,
|
||||
):
|
||||
processor.load(dataset, docs, multimodal_documents=multimodal_docs)
|
||||
vector = mock_vector_cls.return_value
|
||||
vector.create.assert_called_once_with(docs)
|
||||
vector.create_multimodal.assert_called_once_with(multimodal_docs)
|
||||
mock_keyword_cls.assert_not_called()
|
||||
|
||||
def test_load_uses_keyword_add_texts_with_keywords_when_economy(
|
||||
self, processor: ParagraphIndexProcessor, dataset: Mock
|
||||
) -> None:
|
||||
dataset.indexing_technique = "economy"
|
||||
docs = [Document(page_content="chunk", metadata={})]
|
||||
|
||||
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
|
||||
processor.load(dataset, docs, keywords_list=["k1", "k2"])
|
||||
|
||||
mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs, keywords_list=["k1", "k2"])
|
||||
|
||||
def test_load_uses_keyword_add_texts_without_keywords_when_economy(
|
||||
self, processor: ParagraphIndexProcessor, dataset: Mock
|
||||
) -> None:
|
||||
dataset.indexing_technique = "economy"
|
||||
docs = [Document(page_content="chunk", metadata={})]
|
||||
|
||||
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
|
||||
processor.load(dataset, docs)
|
||||
|
||||
mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs)
|
||||
|
||||
def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
|
||||
segment_query = Mock()
|
||||
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
|
||||
session = Mock()
|
||||
session.query.return_value = segment_query
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments"
|
||||
) as mock_summary,
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls,
|
||||
):
|
||||
vector = mock_vector_cls.return_value
|
||||
processor.clean(dataset, ["node-1"], delete_summaries=True)
|
||||
|
||||
mock_summary.assert_called_once_with(dataset, ["seg-1"])
|
||||
vector.delete_by_ids.assert_called_once_with(["node-1"])
|
||||
|
||||
def test_clean_economy_deletes_summaries_and_keywords(
|
||||
self, processor: ParagraphIndexProcessor, dataset: Mock
|
||||
) -> None:
|
||||
dataset.indexing_technique = "economy"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments"
|
||||
) as mock_summary,
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls,
|
||||
):
|
||||
processor.clean(dataset, None, delete_summaries=True)
|
||||
|
||||
mock_summary.assert_called_once_with(dataset, None)
|
||||
mock_keyword_cls.return_value.delete.assert_called_once()
|
||||
|
||||
def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
|
||||
dataset.indexing_technique = "economy"
|
||||
with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls:
|
||||
processor.clean(dataset, ["node-2"], with_keywords=True)
|
||||
|
||||
mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"])
|
||||
|
||||
def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
|
||||
accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9)
|
||||
rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1)
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve"
|
||||
) as mock_retrieve:
|
||||
mock_retrieve.return_value = [accepted, rejected]
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["score"] == 0.9
|
||||
|
||||
def test_index_list_chunks_high_quality(
|
||||
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
|
||||
return_value="hash",
|
||||
),
|
||||
patch.object(
|
||||
processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})]
|
||||
),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"
|
||||
) as mock_store_cls,
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls,
|
||||
):
|
||||
processor.index(dataset, dataset_document, ["chunk-1", "chunk-2"])
|
||||
|
||||
mock_store_cls.return_value.add_documents.assert_called_once()
|
||||
mock_vector_cls.return_value.create.assert_called_once()
|
||||
mock_vector_cls.return_value.create_multimodal.assert_called_once()
|
||||
|
||||
def test_index_list_chunks_economy(
|
||||
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
|
||||
) -> None:
|
||||
dataset.indexing_technique = "economy"
|
||||
with (
|
||||
patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
|
||||
return_value="hash",
|
||||
),
|
||||
patch.object(processor, "_get_content_files", return_value=[]),
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"),
|
||||
patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls,
|
||||
):
|
||||
processor.index(dataset, dataset_document, ["chunk-3"])
|
||||
|
||||
mock_keyword_cls.return_value.add_texts.assert_called_once()
|
||||
|
||||
    def test_index_multimodal_structure_handles_files_and_account_lookup(
        self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """Structured multimodal chunks resolve the creator account and fetch content
        files only for chunks that declare file attachments."""
        chunk_with_files = SimpleNamespace(
            content="content-1",
            files=[SimpleNamespace(id="file-1", filename="image.png")],
        )
        chunk_without_files = SimpleNamespace(content="content-2", files=None)
        structure = SimpleNamespace(general_chunks=[chunk_with_files, chunk_without_files])

        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate",
                return_value=structure,
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user",
                return_value=SimpleNamespace(id="user-1"),
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})]
            ) as mock_files,
            patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"),
            patch("core.rag.index_processor.processor.paragraph_index_processor.Vector"),
        ):
            processor.index(dataset, dataset_document, {"general_chunks": []})

        # Only the chunk that carries file attachments triggers a content-file lookup.
        assert mock_files.call_count == 1
|
||||
|
||||
    def test_index_multimodal_structure_requires_valid_account(
        self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """Indexing a multimodal structure fails when the creator account cannot be loaded."""
        structure = SimpleNamespace(general_chunks=[SimpleNamespace(content="content", files=None)])

        with (
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate",
                return_value=structure,
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user",
                return_value=None,  # simulate a deleted / unknown creator
            ),
        ):
            with pytest.raises(ValueError, match="Invalid account"):
                processor.index(dataset, dataset_document, {"general_chunks": []})
|
||||
|
||||
def test_format_preview_validates_chunk_shape(self, processor: ParagraphIndexProcessor) -> None:
|
||||
preview = processor.format_preview(["chunk-1", "chunk-2"])
|
||||
assert preview["chunk_structure"] == "text_model"
|
||||
assert preview["total_segments"] == 2
|
||||
|
||||
with pytest.raises(ValueError, match="Chunks is not a list"):
|
||||
processor.format_preview({"not": "a-list"})
|
||||
|
||||
    def test_generate_summary_preview_success_and_failure(self, processor: ParagraphIndexProcessor) -> None:
        """Summaries are attached to every preview item; a generation error surfaces as ValueError."""
        preview_items = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")]

        with patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())):
            result = processor.generate_summary_preview(
                "tenant-1", preview_items, {"enable": True}, doc_language="English"
            )
            assert all(item.summary == "summary" for item in result)

        # An underlying failure is wrapped into a ValueError, not propagated raw.
        with patch.object(processor, "generate_summary", side_effect=RuntimeError("summary failed")):
            with pytest.raises(ValueError, match="Failed to generate summaries"):
                processor.generate_summary_preview("tenant-1", [PreviewDetail(content="chunk-1")], {"enable": True})
|
||||
|
||||
    def test_generate_summary_preview_fallback_without_flask_context(self, processor: ParagraphIndexProcessor) -> None:
        """Preview generation still succeeds when no Flask app context is available."""
        preview_items = [PreviewDetail(content="chunk-1")]
        # _get_current_object raising mimics running outside an application context.
        fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app")))

        with (
            patch("flask.current_app", fake_current_app),
            patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())),
        ):
            result = processor.generate_summary_preview("tenant-1", preview_items, {"enable": True})

        assert result[0].summary == "summary"
|
||||
|
||||
    def test_generate_summary_preview_timeout(
        self, processor: ParagraphIndexProcessor, fake_executor_cls: type
    ) -> None:
        """A future still pending after the wait deadline raises a timeout and is cancelled."""
        preview_items = [PreviewDetail(content="chunk-1")]
        future = Mock()
        executor = fake_executor_cls(future)

        with (
            patch("concurrent.futures.ThreadPoolExecutor", return_value=executor),
            # First wait leaves the future unfinished; the second (cleanup) wait finds nothing.
            patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]),
        ):
            with pytest.raises(ValueError, match="timeout"):
                processor.generate_summary_preview("tenant-1", preview_items, {"enable": True})

        future.cancel.assert_called_once()
|
||||
|
||||
def test_generate_summary_validates_input(self) -> None:
|
||||
with pytest.raises(ValueError, match="must be enabled"):
|
||||
ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": False})
|
||||
|
||||
with pytest.raises(ValueError, match="model_name and model_provider_name"):
|
||||
ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True})
|
||||
|
||||
    def test_generate_summary_text_only_flow(self) -> None:
        """The text-only summary path invokes the LLM and survives a quota-deduction failure."""
        model_instance = Mock()
        model_instance.credentials = {"k": "v"}
        # No VISION feature: the processor must not attempt any image handling.
        model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[])
        model_instance.invoke_llm.return_value = self._llm_result("text summary")

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota",
                side_effect=RuntimeError("quota"),
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            summary, usage = ParagraphIndexProcessor.generate_summary(
                "tenant-1",
                "text content",
                {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
                document_language="English",
            )

        assert summary == "text summary"
        assert isinstance(usage, LLMUsage)
        # A quota failure is only logged as a warning; it does not abort summary generation.
        mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota")
|
||||
|
||||
    def test_generate_summary_handles_vision_and_image_conversion(self) -> None:
        """With a VISION-capable model, segment attachments are converted to image prompt
        content and text-based image extraction is skipped."""
        model_instance = Mock()
        model_instance.credentials = {"k": "v"}
        model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
            features=[ModelFeature.VISION]
        )
        model_instance.invoke_llm.return_value = self._llm_result("vision summary")
        image_file = SimpleNamespace()
        image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png")

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
            ),
            patch.object(
                ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[image_file]
            ),
            patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[]) as mock_extract_text,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content",
                return_value=image_content,
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"),
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            summary, _ = ParagraphIndexProcessor.generate_summary(
                "tenant-1",
                "text content",
                {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
                segment_id="seg-1",
            )

        assert summary == "vision summary"
        # Attachment images were found, so text-level image extraction is bypassed.
        mock_extract_text.assert_not_called()
|
||||
|
||||
    def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None:
        """A prompt template with unknown placeholders and a failing image conversion are
        tolerated, but a non-LLMResult response is rejected."""
        model_instance = Mock()
        model_instance.credentials = {"k": "v"}
        model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
            features=[ModelFeature.VISION]
        )
        # Not an LLMResult: generate_summary must refuse this return value.
        model_instance.invoke_llm.return_value = object()
        image_file = SimpleNamespace()

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls,
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance",
                return_value=model_instance,
            ),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.DEFAULT_GENERATOR_SUMMARY_PROMPT",
                "Prompt {missing}",  # placeholder not supplied by the processor
            ),
            patch.object(ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[]),
            patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[image_file]),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content",
                side_effect=RuntimeError("bad image"),
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock()
            with pytest.raises(ValueError, match="Expected LLMResult"):
                ParagraphIndexProcessor.generate_summary(
                    "tenant-1",
                    "text content",
                    {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"},
                )

        # The image-conversion failure is logged but does not abort the call on its own.
        mock_logger.warning.assert_called_with(
            "Failed to convert image file to prompt message content: %s", "bad image"
        )
|
||||
|
||||
    def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None:
        """Only uploads with an image mime type produce file objects from text references."""
        # NOTE(review): these literals look garbled in this copy — the original text
        # presumably contained markdown image references; verify against the repository.
        text = (
            " "
            " "
            ""
        )
        image_upload = SimpleNamespace(
            id="11111111-1111-1111-1111-111111111111",
            tenant_id="tenant-1",
            name="image.png",
            mime_type="image/png",
            extension="png",
            source_url="",
            size=1,
            key="key",
        )
        non_image_upload = SimpleNamespace(
            id="22222222-2222-2222-2222-222222222222",
            tenant_id="tenant-1",
            name="file.txt",
            mime_type="text/plain",
            extension="txt",
            source_url="",
            size=1,
            key="key",
        )
        query = Mock()
        query.where.return_value.all.return_value = [image_upload, non_image_upload]
        session = Mock()
        session.query.return_value = query

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
                return_value=SimpleNamespace(id="file-1"),
            ) as mock_builder,
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text)

        # Only the image upload is built; the text/plain upload is skipped without warnings.
        assert len(files) == 1
        assert mock_builder.call_count == 1
        mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None:
|
||||
assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here") == []
|
||||
|
||||
    def test_extract_images_from_text_logs_when_build_fails(self) -> None:
        """A failure while building the file object is logged and yields no files."""
        # NOTE(review): this literal looks garbled in this copy — it was presumably a
        # markdown image reference originally; verify against the repository.
        text = ""
        image_upload = SimpleNamespace(
            id="11111111-1111-1111-1111-111111111111",
            tenant_id="tenant-1",
            name="image.png",
            mime_type="image/png",
            extension="png",
            source_url="",
            size=1,
            key="key",
        )
        query = Mock()
        query.where.return_value.all.return_value = [image_upload]
        session = Mock()
        session.query.return_value = query

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping",
                side_effect=RuntimeError("build failed"),
            ),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text)

        assert files == []
        mock_logger.warning.assert_called_once()
|
||||
|
||||
    def test_extract_images_from_segment_attachments(self) -> None:
        """Only well-formed image attachments are returned; a malformed one logs a warning."""
        image_upload = SimpleNamespace(
            id="file-1",
            name="image",
            extension="png",
            mime_type="image/png",
            source_url="",
            size=1,
            key="k1",
        )
        # extension=None — presumably breaks file construction, producing the one warning
        # asserted below; confirm against the processor implementation.
        bad_upload = SimpleNamespace(
            id="file-2",
            name="broken",
            extension=None,
            mime_type="image/png",
            source_url="",
            size=1,
            key="k2",
        )
        non_image_upload = SimpleNamespace(
            id="file-3",
            name="text",
            extension="txt",
            mime_type="text/plain",
            source_url="",
            size=1,
            key="k3",
        )
        execute_result = Mock()
        execute_result.all.return_value = [(None, image_upload), (None, bad_upload), (None, non_image_upload)]
        session = Mock()
        session.execute.return_value = execute_result

        with (
            patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
            patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger,
        ):
            files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1")

        assert len(files) == 1
        mock_logger.warning.assert_called_once()
|
||||
|
||||
def test_extract_images_from_segment_attachments_empty(self) -> None:
|
||||
execute_result = Mock()
|
||||
execute_result.all.return_value = []
|
||||
session = Mock()
|
||||
session.execute.return_value = execute_result
|
||||
|
||||
with patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session):
|
||||
empty_files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1")
|
||||
|
||||
assert empty_files == []
|
||||
@ -0,0 +1,523 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode
|
||||
|
||||
|
||||
class TestParentChildIndexProcessor:
|
||||
    @pytest.fixture
    def processor(self) -> ParentChildIndexProcessor:
        """A fresh ParentChildIndexProcessor instance for each test."""
        return ParentChildIndexProcessor()
|
||||
|
||||
@pytest.fixture
|
||||
def dataset(self) -> Mock:
|
||||
dataset = Mock()
|
||||
dataset.id = "dataset-1"
|
||||
dataset.tenant_id = "tenant-1"
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.is_multimodal = True
|
||||
return dataset
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_document(self) -> Mock:
|
||||
document = Mock()
|
||||
document.id = "doc-1"
|
||||
document.created_by = "user-1"
|
||||
document.dataset_process_rule_id = None
|
||||
return document
|
||||
|
||||
def _segmentation(self) -> SimpleNamespace:
|
||||
return SimpleNamespace(max_tokens=200, chunk_overlap=10, separator="\n")
|
||||
|
||||
def _paragraph_rules(self) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
parent_mode=ParentMode.PARAGRAPH,
|
||||
segmentation=self._segmentation(),
|
||||
subchunk_segmentation=self._segmentation(),
|
||||
)
|
||||
|
||||
def _full_doc_rules(self) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
parent_mode=ParentMode.FULL_DOC, segmentation=None, subchunk_segmentation=self._segmentation()
|
||||
)
|
||||
|
||||
def test_extract_forwards_automatic_flag(self, processor: ParentChildIndexProcessor) -> None:
|
||||
extract_setting = Mock()
|
||||
expected = [Document(page_content="chunk", metadata={})]
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.ExtractProcessor.extract"
|
||||
) as mock_extract:
|
||||
mock_extract.return_value = expected
|
||||
documents = processor.extract(extract_setting, process_rule_mode="hierarchical")
|
||||
|
||||
assert documents == expected
|
||||
mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True)
|
||||
|
||||
def test_transform_validates_process_rule(self, processor: ParentChildIndexProcessor) -> None:
|
||||
with pytest.raises(ValueError, match="No process rule found"):
|
||||
processor.transform([Document(page_content="text", metadata={})], process_rule=None)
|
||||
|
||||
with pytest.raises(ValueError, match="No rules found in process rule"):
|
||||
processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})
|
||||
|
||||
def test_transform_paragraph_requires_segmentation(self, processor: ParentChildIndexProcessor) -> None:
|
||||
rules = SimpleNamespace(parent_mode=ParentMode.PARAGRAPH, segmentation=None)
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", return_value=rules
|
||||
):
|
||||
with pytest.raises(ValueError, match="No segmentation found in rules"):
|
||||
processor.transform(
|
||||
[Document(page_content="text", metadata={})],
|
||||
process_rule={"mode": "custom", "rules": {"enabled": True}},
|
||||
)
|
||||
|
||||
    def test_transform_paragraph_builds_parent_and_child_docs(self, processor: ParentChildIndexProcessor) -> None:
        """Paragraph mode keeps non-empty cleaned parents, strips the leading separator
        remnant, and attaches child documents plus content-file attachments."""
        splitter = Mock()
        splitter.split_documents.return_value = [
            Document(page_content=".parent", metadata={}),
            Document(page_content=" ", metadata={}),  # whitespace-only chunk must be dropped
        ]
        parent_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
        child_docs = [ChildDocument(page_content="child-1", metadata={"dataset_id": "dataset-1"})]

        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate",
                return_value=self._paragraph_rules(),
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean",
                return_value=".parent",
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
            ),
            patch.object(processor, "_split_child_nodes", return_value=child_docs),
        ):
            result = processor.transform(
                [parent_document],
                process_rule={"mode": "custom", "rules": {"enabled": True}},
                preview=False,
            )

        assert len(result) == 1
        # Leading "." (separator remnant) is stripped from the parent content.
        assert result[0].page_content == "parent"
        assert result[0].children == child_docs
        assert result[0].attachments is not None
|
||||
|
||||
    def test_transform_preview_returns_after_ten_parent_chunks(self, processor: ParentChildIndexProcessor) -> None:
        """Preview mode caps output at ten parent chunks, so the second document adds nothing."""
        splitter = Mock()
        splitter.split_documents.return_value = [Document(page_content=f"chunk-{i}", metadata={}) for i in range(10)]
        documents = [
            Document(page_content="doc-1", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}),
            Document(page_content="doc-2", metadata={"dataset_id": "dataset-1", "document_id": "doc-2"}),
        ]

        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate",
                return_value=self._paragraph_rules(),
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean",
                side_effect=lambda text, _: text,  # identity clean keeps chunk text unchanged
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            patch.object(processor, "_get_content_files", return_value=[]),
            patch.object(processor, "_split_child_nodes", return_value=[]),
        ):
            result = processor.transform(
                documents,
                process_rule={"mode": "custom", "rules": {"enabled": True}},
                preview=True,
            )

        assert len(result) == 10
|
||||
|
||||
    def test_transform_full_doc_mode_trims_children_for_preview(self, processor: ParentChildIndexProcessor) -> None:
        """Full-doc mode merges input into one parent and trims children to the preview limit."""
        docs = [
            Document(page_content="first", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}),
            Document(page_content="second", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}),
        ]
        child_docs = [ChildDocument(page_content=f"child-{i}", metadata={}) for i in range(5)]

        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate",
                return_value=self._full_doc_rules(),
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
            ),
            patch.object(processor, "_split_child_nodes", return_value=child_docs),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            # Cap the preview to two children so the trimming is observable.
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.dify_config.CHILD_CHUNKS_PREVIEW_NUMBER",
                2,
            ),
        ):
            result = processor.transform(
                docs,
                process_rule={"mode": "hierarchical", "rules": {"enabled": True}},
                preview=True,
            )

        assert len(result) == 1
        assert len(result[0].children or []) == 2
        assert result[0].attachments is not None
|
||||
|
||||
    def test_load_creates_vectors_for_child_docs(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """load() indexes the flattened child documents and passes multimodal docs through."""
        parent_doc = Document(
            page_content="parent",
            metadata={},
            children=[
                ChildDocument(page_content="child-1", metadata={}),
                ChildDocument(page_content="child-2", metadata={}),
            ],
        )
        multimodal_docs = [AttachmentDocument(page_content="image", metadata={})]

        with patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls:
            vector = mock_vector_cls.return_value
            processor.load(dataset, [parent_doc], multimodal_documents=multimodal_docs)

        # Each child becomes one plain Document in the vector index.
        assert vector.create.call_count == 1
        formatted_docs = vector.create.call_args[0][0]
        assert len(formatted_docs) == 2
        assert all(isinstance(doc, Document) for doc in formatted_docs)
        vector.create_multimodal.assert_called_once_with(multimodal_docs)
|
||||
|
||||
def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
|
||||
delete_query = Mock()
|
||||
where_query = Mock()
|
||||
where_query.delete.return_value = 2
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value = where_query
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
|
||||
patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
|
||||
):
|
||||
vector = mock_vector_cls.return_value
|
||||
processor.clean(
|
||||
dataset,
|
||||
["node-1"],
|
||||
delete_child_chunks=True,
|
||||
precomputed_child_node_ids=["child-1", "child-2"],
|
||||
)
|
||||
|
||||
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
|
||||
where_query.delete.assert_called_once_with(synchronize_session=False)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
    def test_clean_queries_child_ids_when_not_precomputed(
        self, processor: ParentChildIndexProcessor, dataset: Mock
    ) -> None:
        """Child ids are looked up from the database and None entries are filtered out."""
        child_query = Mock()
        child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)]
        session = Mock()
        session.query.return_value = child_query

        with (
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            vector = mock_vector_cls.return_value
            processor.clean(dataset, ["node-1"], delete_child_chunks=False)

        # The None row is dropped before the vector deletion call.
        vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
|
||||
|
||||
    def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """Cleaning without node ids wipes the whole vector index and all child rows."""
        where_query = Mock()
        where_query.delete.return_value = 3
        session = Mock()
        session.query.return_value.where.return_value = where_query

        with (
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            vector = mock_vector_cls.return_value
            processor.clean(dataset, None, delete_child_chunks=True)

        vector.delete.assert_called_once()
        where_query.delete.assert_called_once_with(synchronize_session=False)
        session.commit.assert_called_once()
|
||||
|
||||
    def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        """With delete_summaries=True, affected segment ids are resolved and forwarded."""
        segment_query = Mock()
        segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
        session = Mock()
        session.query.return_value = segment_query
        # session_factory.create_session is used as a context manager; emulate one.
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
        session_ctx.__exit__.return_value = False

        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.session_factory.create_session",
                return_value=session_ctx,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments"
            ) as mock_summary,
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"),
        ):
            processor.clean(dataset, ["node-1"], delete_summaries=True, precomputed_child_node_ids=[])

        mock_summary.assert_called_once_with(dataset, ["seg-1"])
|
||||
|
||||
def test_clean_deletes_all_summaries_when_node_ids_missing(
|
||||
self, processor: ParentChildIndexProcessor, dataset: Mock
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments"
|
||||
) as mock_summary,
|
||||
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"),
|
||||
):
|
||||
processor.clean(dataset, None, delete_summaries=True)
|
||||
|
||||
mock_summary.assert_called_once_with(dataset, None)
|
||||
|
||||
def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
|
||||
ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8)
|
||||
low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2)
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve"
|
||||
) as mock_retrieve:
|
||||
mock_retrieve.return_value = [ok_result, low_result]
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {})
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "keep"
|
||||
assert docs[0].metadata["score"] == 0.8
|
||||
|
||||
def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None:
|
||||
rules = SimpleNamespace(subchunk_segmentation=None)
|
||||
|
||||
with pytest.raises(ValueError, match="No subchunk segmentation found"):
|
||||
processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None)
|
||||
|
||||
    def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None:
        """Child splitting keeps non-empty chunks, strips separator remnants, and hashes content."""
        rules = SimpleNamespace(subchunk_segmentation=self._segmentation())
        splitter = Mock()
        splitter.split_documents.return_value = [
            Document(page_content=".child-1", metadata={}),
            Document(page_content=" ", metadata={}),  # whitespace-only chunk must be dropped
        ]

        with (
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
        ):
            child_docs = processor._split_child_nodes(
                Document(page_content="parent", metadata={}), rules, "custom", None
            )

        assert len(child_docs) == 1
        # Leading "." (separator remnant) is stripped; hash stored in metadata.
        assert child_docs[0].page_content == "child-1"
        assert child_docs[0].metadata["doc_hash"] == "hash"
|
||||
|
||||
    def test_index_creates_process_rule_segments_and_vectors(
        self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """Indexing structured parent/child chunks persists a new process rule, stores
        segments, and builds both text and multimodal vector indexes."""
        parent_childs = SimpleNamespace(
            parent_mode=ParentMode.PARAGRAPH,
            parent_child_chunks=[
                SimpleNamespace(
                    parent_content="parent text",
                    child_contents=["child-1", "child-2"],
                    files=[SimpleNamespace(id="file-1", filename="image.png")],
                )
            ],
        )
        dataset_rule = SimpleNamespace(id="rule-1")
        session = Mock()

        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
                return_value=parent_childs,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule",
                return_value=dataset_rule,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                side_effect=lambda text: f"hash-{text}",
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"
            ) as mock_store_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            processor.index(dataset, dataset_document, {"parent_child_chunks": []})

        # The freshly created process rule is linked to the document and committed.
        assert dataset_document.dataset_process_rule_id == "rule-1"
        session.add.assert_called_once_with(dataset_rule)
        session.flush.assert_called_once()
        session.commit.assert_called_once()
        mock_store_cls.return_value.add_documents.assert_called_once()
        assert mock_vector_cls.return_value.create.call_count == 1
        mock_vector_cls.return_value.create_multimodal.assert_called_once()
|
||||
|
||||
    def test_index_uses_content_files_when_files_missing(
        self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """When a chunk carries no explicit files, index() falls back to extracting files from the content.

        The chunk's ``files`` attribute is None, so the processor is expected to
        call its own ``_get_content_files`` helper exactly once instead.
        """
        parent_childs = SimpleNamespace(
            parent_mode=ParentMode.PARAGRAPH,
            parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)],
        )
        dataset_rule = SimpleNamespace(id="rule-1")
        session = Mock()

        with (
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
                return_value=parent_childs,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule",
                return_value=dataset_rule,
            ),
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
                return_value="hash",
            ),
            # The content-file path needs a resolvable account for downloads.
            patch(
                "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user",
                return_value=SimpleNamespace(id="user-1"),
            ),
            patch.object(
                processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})]
            ) as mock_files,
            patch("core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"),
            patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"),
            patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session),
        ):
            processor.index(dataset, dataset_document, {"parent_child_chunks": []})

        mock_files.assert_called_once()
|
||||
|
||||
def test_index_raises_when_account_missing(
|
||||
self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock
|
||||
) -> None:
|
||||
parent_childs = SimpleNamespace(
|
||||
parent_mode=ParentMode.PARAGRAPH,
|
||||
parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
|
||||
return_value=parent_childs,
|
||||
),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash",
|
||||
return_value="hash",
|
||||
),
|
||||
patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid account"):
|
||||
processor.index(dataset, dataset_document, {"parent_child_chunks": []})
|
||||
|
||||
def test_format_preview_returns_parent_child_structure(self, processor: ParentChildIndexProcessor) -> None:
|
||||
parent_childs = SimpleNamespace(
|
||||
parent_mode=ParentMode.PARAGRAPH,
|
||||
parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child-1", "child-2"])],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate",
|
||||
return_value=parent_childs,
|
||||
):
|
||||
preview = processor.format_preview({"parent_child_chunks": []})
|
||||
|
||||
assert preview["chunk_structure"] == "hierarchical_model"
|
||||
assert preview["parent_mode"] == ParentMode.PARAGRAPH
|
||||
assert preview["total_segments"] == 1
|
||||
|
||||
def test_generate_summary_preview_sets_summaries(self, processor: ParentChildIndexProcessor) -> None:
|
||||
preview_texts = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")]
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary",
|
||||
return_value=("summary", None),
|
||||
):
|
||||
result = processor.generate_summary_preview(
|
||||
"tenant-1", preview_texts, {"enable": True}, doc_language="English"
|
||||
)
|
||||
|
||||
assert all(item.summary == "summary" for item in result)
|
||||
|
||||
def test_generate_summary_preview_raises_when_worker_fails(self, processor: ParentChildIndexProcessor) -> None:
|
||||
preview_texts = [PreviewDetail(content="chunk-1")]
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary",
|
||||
side_effect=RuntimeError("summary failed"),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Failed to generate summaries"):
|
||||
processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True})
|
||||
|
||||
    def test_generate_summary_preview_falls_back_without_flask_context(
        self, processor: ParentChildIndexProcessor
    ) -> None:
        """Summaries are still produced when no Flask application context is available.

        ``current_app._get_current_object`` raising RuntimeError simulates code
        running outside an app context; the processor is expected to fall back
        to a context-free execution path rather than propagate the error.
        """
        preview_texts = [PreviewDetail(content="chunk-1")]
        # Stand-in for flask.current_app whose unwrap call always fails.
        fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app")))

        with (
            patch("flask.current_app", fake_current_app),
            patch(
                "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary",
                return_value=("summary", None),
            ),
        ):
            result = processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True})

        assert result[0].summary == "summary"
|
||||
|
||||
    def test_generate_summary_preview_handles_timeout(
        self, processor: ParentChildIndexProcessor, fake_executor_cls: type
    ) -> None:
        """An unfinished summary future triggers a timeout ValueError and is cancelled.

        NOTE(review): ``fake_executor_cls`` is a fixture defined elsewhere in
        this file — presumably a ThreadPoolExecutor stand-in wrapping the given
        future; confirm against the fixture definition.
        """
        preview_texts = [PreviewDetail(content="chunk-1")]
        future = Mock()
        executor = fake_executor_cls(future)

        with (
            patch("concurrent.futures.ThreadPoolExecutor", return_value=executor),
            # First wait() call reports the future still pending (not_done set),
            # the second reports nothing outstanding.
            patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]),
        ):
            with pytest.raises(ValueError, match="timeout"):
                processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True})

        # The pending future must be cancelled after the timeout is raised.
        future.cancel.assert_called_once()
|
||||
@ -0,0 +1,382 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
|
||||
|
||||
class _ImmediateThread:
|
||||
def __init__(self, target, args=(), kwargs=None):
|
||||
self._target = target
|
||||
self._args = args
|
||||
self._kwargs = kwargs or {}
|
||||
|
||||
def start(self) -> None:
|
||||
self._target(*self._args, **self._kwargs)
|
||||
|
||||
def join(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class TestQAIndexProcessor:
    """Unit tests for QAIndexProcessor: extract, transform, CSV import, load, clean, retrieve, index, and QA formatting."""

    @pytest.fixture
    def processor(self) -> QAIndexProcessor:
        """Fresh processor instance per test."""
        return QAIndexProcessor()

    @pytest.fixture
    def dataset(self) -> Mock:
        """Dataset mock configured as a high-quality, multimodal dataset."""
        dataset = Mock()
        dataset.id = "dataset-1"
        dataset.tenant_id = "tenant-1"
        dataset.indexing_technique = "high_quality"
        dataset.is_multimodal = True
        return dataset

    @pytest.fixture
    def dataset_document(self) -> Mock:
        """Minimal dataset-document mock with an id and creator."""
        document = Mock()
        document.id = "doc-1"
        document.created_by = "user-1"
        return document

    @pytest.fixture
    def process_rule(self) -> dict:
        """Custom-mode process rule dict with a plain segmentation config."""
        return {
            "mode": "custom",
            "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}},
        }

    def _rules(self) -> SimpleNamespace:
        """Object-style counterpart of process_rule, as Rule.model_validate would return it."""
        segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n")
        return SimpleNamespace(segmentation=segmentation)

    def test_extract_forwards_automatic_flag(self, processor: QAIndexProcessor) -> None:
        """process_rule_mode='automatic' is translated into is_automatic=True for ExtractProcessor."""
        extract_setting = Mock()
        expected_docs = [Document(page_content="chunk", metadata={})]

        with patch("core.rag.index_processor.processor.qa_index_processor.ExtractProcessor.extract") as mock_extract:
            mock_extract.return_value = expected_docs

            docs = processor.extract(extract_setting, process_rule_mode="automatic")

        assert docs == expected_docs
        mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True)

    def test_transform_rejects_none_process_rule(self, processor: QAIndexProcessor) -> None:
        """transform() requires a process rule."""
        with pytest.raises(ValueError, match="No process rule found"):
            processor.transform([Document(page_content="text", metadata={})], process_rule=None)

    def test_transform_rejects_missing_rules_key(self, processor: QAIndexProcessor) -> None:
        """transform() requires the 'rules' key inside the process rule."""
        with pytest.raises(ValueError, match="No rules found in process rule"):
            processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})

    def test_transform_preview_calls_formatter_once(
        self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app
    ) -> None:
        """Preview mode formats exactly one QA document (no thread fan-out)."""
        document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
        # Leading '.' exercises remove_leading_symbols below.
        split_node = Document(page_content=".question", metadata={})
        splitter = Mock()
        splitter.split_documents.return_value = [split_node]

        def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language):
            # Stand-in for _format_qa_document: push one canned QA pair.
            all_qa_documents.append(Document(page_content="Q1", metadata={"answer": "A1"}))

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules()
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", return_value="clean text"
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols",
                side_effect=lambda text: text.lstrip("."),
            ),
            patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format,
            patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app,
        ):
            mock_current_app._get_current_object.return_value = fake_flask_app
            result = processor.transform(
                [document],
                process_rule=process_rule,
                preview=True,
                tenant_id="tenant-1",
                doc_language="English",
            )

        assert len(result) == 1
        assert result[0].metadata["answer"] == "A1"
        mock_format.assert_called_once()

    def test_transform_non_preview_uses_thread_batches(
        self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app
    ) -> None:
        """Non-preview mode fans out over threads; _ImmediateThread keeps the run synchronous."""
        documents = [
            Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}),
            Document(page_content="doc-2", metadata={"document_id": "doc-2", "dataset_id": "dataset-1"}),
        ]
        split_node = Document(page_content="question", metadata={})
        splitter = Mock()
        splitter.split_documents.return_value = [split_node]

        def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language):
            all_qa_documents.append(Document(page_content=f"Q-{document_node.page_content}", metadata={"answer": "A"}))

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules()
            ),
            patch.object(processor, "_get_splitter", return_value=splitter),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean",
                side_effect=lambda text, _: text,
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols",
                side_effect=lambda text: text,
            ),
            patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format,
            patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app,
            # Replace real threads with the synchronous stand-in defined above.
            patch(
                "core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread
            ),
        ):
            mock_current_app._get_current_object.return_value = fake_flask_app
            result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1")

        # One QA document per input document; formatter invoked once per doc.
        assert len(result) == 2
        assert mock_format.call_count == 2

    def test_format_by_template_validates_file_type(self, processor: QAIndexProcessor) -> None:
        """Non-CSV uploads are rejected."""
        not_csv_file = Mock(spec=FileStorage)
        not_csv_file.filename = "qa.txt"

        with pytest.raises(ValueError, match="Only CSV files"):
            processor.format_by_template(not_csv_file)

    def test_format_by_template_parses_csv_rows(self, processor: QAIndexProcessor) -> None:
        """Each CSV row becomes a Document: column 0 is the question, column 1 the answer."""
        csv_file = Mock(spec=FileStorage)
        csv_file.filename = "qa.csv"
        dataframe = pd.DataFrame([["Q1", "A1"], ["Q2", "A2"]])

        with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=dataframe):
            docs = processor.format_by_template(csv_file)

        assert [doc.page_content for doc in docs] == ["Q1", "Q2"]
        assert [doc.metadata["answer"] for doc in docs] == ["A1", "A2"]

    def test_format_by_template_raises_on_empty_csv(self, processor: QAIndexProcessor) -> None:
        """An empty CSV is reported as an error."""
        csv_file = Mock(spec=FileStorage)
        csv_file.filename = "qa.csv"

        with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=pd.DataFrame()):
            with pytest.raises(ValueError, match="empty"):
                processor.format_by_template(csv_file)

    def test_format_by_template_raises_on_invalid_csv(self, processor: QAIndexProcessor) -> None:
        """Parse failures from pandas are wrapped in a ValueError carrying the original message."""
        csv_file = Mock(spec=FileStorage)
        csv_file.filename = "qa.csv"

        with patch(
            "core.rag.index_processor.processor.qa_index_processor.pd.read_csv", side_effect=Exception("bad csv")
        ):
            with pytest.raises(ValueError, match="bad csv"):
                processor.format_by_template(csv_file)

    def test_load_creates_vectors_for_high_quality_dataset(self, processor: QAIndexProcessor, dataset: Mock) -> None:
        """High-quality datasets get both text and multimodal vectors created on load."""
        docs = [Document(page_content="Q1", metadata={"answer": "A1"})]
        multimodal_docs = [AttachmentDocument(page_content="image", metadata={})]

        with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls:
            vector = mock_vector_cls.return_value
            processor.load(dataset, docs, multimodal_documents=multimodal_docs)

        vector.create.assert_called_once_with(docs)
        vector.create_multimodal.assert_called_once_with(multimodal_docs)

    def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None:
        """Economy datasets never instantiate the vector store."""
        dataset.indexing_technique = "economy"
        docs = [Document(page_content="Q1", metadata={"answer": "A1"})]

        with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls:
            processor.load(dataset, docs)

        mock_vector_cls.assert_not_called()

    def test_clean_handles_summary_deletion_and_vector_cleanup(
        self, processor: QAIndexProcessor, dataset: Mock
    ) -> None:
        """Cleaning specific nodes deletes their summaries (by segment id) and their vectors (by node id)."""
        mock_segment = SimpleNamespace(id="seg-1")
        # Emulate session.query(...).filter(...).all() -> [segment].
        mock_query = Mock()
        mock_query.filter.return_value.all.return_value = [mock_segment]
        mock_session = Mock()
        mock_session.query.return_value = mock_query
        # Context manager returned by session_factory.create_session().
        session_context = MagicMock()
        session_context.__enter__.return_value = mock_session
        session_context.__exit__.return_value = False

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.session_factory.create_session",
                return_value=session_context,
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments"
            ) as mock_summary,
            patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls,
        ):
            vector = mock_vector_cls.return_value
            processor.clean(dataset, ["node-1"], delete_summaries=True)

        mock_summary.assert_called_once_with(dataset, ["seg-1"])
        vector.delete_by_ids.assert_called_once_with(["node-1"])

    def test_clean_handles_dataset_wide_cleanup(self, processor: QAIndexProcessor, dataset: Mock) -> None:
        """node_ids=None cleans the whole dataset: summaries for all segments and a full vector delete."""
        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments"
            ) as mock_summary,
            patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls,
        ):
            vector = mock_vector_cls.return_value
            processor.clean(dataset, None, delete_summaries=True)

        mock_summary.assert_called_once_with(dataset, None)
        vector.delete.assert_called_once()

    def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None:
        """Results below the score threshold are dropped; the score is copied into metadata."""
        result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9)
        result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1)

        with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve:
            mock_retrieve.return_value = [result_ok, result_low]
            docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})

        assert len(docs) == 1
        assert docs[0].page_content == "accepted"
        assert docs[0].metadata["score"] == 0.9

    def test_index_adds_documents_and_vectors_for_high_quality(
        self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """index() stores the QA segments and creates vectors for a high-quality dataset."""
        qa_chunks = SimpleNamespace(
            qa_chunks=[
                SimpleNamespace(question="Q1", answer="A1"),
                SimpleNamespace(question="Q2", answer="A2"),
            ]
        )

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate",
                return_value=qa_chunks,
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
            ),
            patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore") as mock_store_cls,
            patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls,
        ):
            processor.index(dataset, dataset_document, {"qa_chunks": []})

        mock_store_cls.return_value.add_documents.assert_called_once()
        mock_vector_cls.return_value.create.assert_called_once()

    def test_index_requires_high_quality(
        self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
    ) -> None:
        """QA indexing is only supported for high-quality datasets."""
        dataset.indexing_technique = "economy"
        qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")])

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate",
                return_value=qa_chunks,
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
            ),
            patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore"),
        ):
            with pytest.raises(ValueError, match="must be high quality"):
                processor.index(dataset, dataset_document, {"qa_chunks": []})

    def test_format_preview_returns_qa_preview(self, processor: QAIndexProcessor) -> None:
        """format_preview reports the qa_model structure, segment count, and question/answer pairs."""
        qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")])

        with patch(
            "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate",
            return_value=qa_chunks,
        ):
            preview = processor.format_preview({"qa_chunks": []})

        assert preview["chunk_structure"] == "qa_model"
        assert preview["total_segments"] == 1
        assert preview["qa_preview"] == [{"question": "Q1", "answer": "A1"}]

    def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None:
        """QA processor performs no summarization: the input list is returned unchanged (same object)."""
        preview_items = [PreviewDetail(content="Q1")]
        assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items

    def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None:
        """Whitespace-only content produces no QA documents."""
        all_qa_documents: list[Document] = []
        blank_document = Document(page_content=" ", metadata={})

        processor._format_qa_document(fake_flask_app, "tenant-1", blank_document, all_qa_documents, "English")

        assert all_qa_documents == []

    def test_format_qa_document_builds_question_answer_documents(
        self, processor: QAIndexProcessor, fake_flask_app
    ) -> None:
        """LLM output in 'Qn:/An:' format is split into one Document per question with the answer in metadata."""
        all_qa_documents: list[Document] = []
        source_document = Document(page_content="source text", metadata={"origin": "doc-1"})

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document",
                return_value="Q1: What is this?\nA1: A test.\nQ2: Why?\nA2: Coverage.",
            ),
            patch(
                "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash"
            ),
        ):
            processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English")

        assert len(all_qa_documents) == 2
        assert all_qa_documents[0].page_content == "What is this?"
        assert all_qa_documents[0].metadata["answer"] == "A test."
        assert all_qa_documents[1].metadata["answer"] == "Coverage."

    def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None:
        """LLM failures are swallowed (no documents appended) but logged via logger.exception."""
        all_qa_documents: list[Document] = []
        source_document = Document(page_content="source text", metadata={"origin": "doc-1"})

        with (
            patch(
                "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document",
                side_effect=RuntimeError("llm failure"),
            ),
            patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger,
        ):
            processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English")

        assert all_qa_documents == []
        mock_logger.exception.assert_called_once_with("Failed to format qa document")

    def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None:
        """_format_split_text parses numbered Q/A markers into question-answer dicts."""
        parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n")

        assert parsed == [{"question": "First?", "answer": "One."}, {"question": "Second?", "answer": "Two."}]
|
||||
@ -0,0 +1,291 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
|
||||
|
||||
class _ForwardingBaseIndexProcessor(BaseIndexProcessor):
    """Concrete shim over the abstract BaseIndexProcessor.

    Every override forwards to ``super()`` unchanged, so the tests can
    instantiate the class and exercise both the base implementations of the
    abstract hooks (expected to raise NotImplementedError) and the shared
    concrete helpers (``_get_splitter``, ``_extract_markdown_images``, ...).
    """

    def extract(self, extract_setting, **kwargs):
        # Forward unchanged so the base behavior is what the tests observe.
        return super().extract(extract_setting, **kwargs)

    def transform(self, documents, current_user=None, **kwargs):
        return super().transform(documents, current_user=current_user, **kwargs)

    def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting, doc_language=None):
        return super().generate_summary_preview(
            tenant_id=tenant_id,
            preview_texts=preview_texts,
            summary_index_setting=summary_index_setting,
            doc_language=doc_language,
        )

    def load(self, dataset, documents, multimodal_documents=None, with_keywords=True, **kwargs):
        return super().load(
            dataset=dataset,
            documents=documents,
            multimodal_documents=multimodal_documents,
            with_keywords=with_keywords,
            **kwargs,
        )

    def clean(self, dataset, node_ids, with_keywords=True, **kwargs):
        return super().clean(dataset=dataset, node_ids=node_ids, with_keywords=with_keywords, **kwargs)

    def index(self, dataset, document, chunks):
        return super().index(dataset=dataset, document=document, chunks=chunks)

    def format_preview(self, chunks):
        return super().format_preview(chunks)

    def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model):
        return super().retrieve(
            retrieval_method=retrieval_method,
            query=query,
            dataset=dataset,
            top_k=top_k,
            score_threshold=score_threshold,
            reranking_model=reranking_model,
        )
|
||||
|
||||
|
||||
class TestBaseIndexProcessor:
|
||||
    @pytest.fixture
    def processor(self) -> _ForwardingBaseIndexProcessor:
        """Fresh forwarding shim per test; no state is shared between tests."""
        return _ForwardingBaseIndexProcessor()
|
||||
|
||||
def test_abstract_methods_raise_not_implemented(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.extract(Mock())
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.transform([])
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {})
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.load(Mock(), [])
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.clean(Mock(), None)
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.index(Mock(), Mock(), {})
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.format_preview([])
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.retrieve("semantic_search", "q", Mock(), 3, 0.5, {})
|
||||
|
||||
def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH", 1000
|
||||
):
|
||||
with pytest.raises(ValueError, match="between 50 and 1000"):
|
||||
processor._get_splitter("custom", 49, 0, "", None)
|
||||
with pytest.raises(ValueError, match="between 50 and 1000"):
|
||||
processor._get_splitter("custom", 1001, 0, "", None)
|
||||
|
||||
def test_get_splitter_custom_mode_uses_fixed_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
fixed_splitter = Mock()
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.FixedRecursiveCharacterTextSplitter.from_encoder",
|
||||
return_value=fixed_splitter,
|
||||
) as mock_fixed:
|
||||
splitter = processor._get_splitter("hierarchical", 120, 10, "\\n\\n", None)
|
||||
|
||||
assert splitter is fixed_splitter
|
||||
assert mock_fixed.call_args.kwargs["fixed_separator"] == "\n\n"
|
||||
assert mock_fixed.call_args.kwargs["chunk_size"] == 120
|
||||
|
||||
def test_get_splitter_automatic_mode_uses_enhance_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
auto_splitter = Mock()
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.EnhanceRecursiveCharacterTextSplitter.from_encoder",
|
||||
return_value=auto_splitter,
|
||||
) as mock_enhance:
|
||||
splitter = processor._get_splitter("automatic", 0, 0, "", None)
|
||||
|
||||
assert splitter is auto_splitter
|
||||
assert "chunk_size" in mock_enhance.call_args.kwargs
|
||||
|
||||
def test_extract_markdown_images(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
markdown = "text  and "
|
||||
images = processor._extract_markdown_images(markdown)
|
||||
assert images == ["https://a/img.png", "/files/123/file-preview"]
|
||||
|
||||
def test_get_content_files_without_images_returns_empty(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
document = Document(page_content="no image markdown", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
|
||||
assert processor._get_content_files(document) == []
|
||||
|
||||
    def test_get_content_files_handles_all_sources_and_duplicates(
        self, processor: _ForwardingBaseIndexProcessor
    ) -> None:
        """_get_content_files resolves every supported image source and keeps duplicates.

        Covers: local image-preview (duplicated), local file-preview, tool file
        path, and a remote URL. Tool and remote sources go through their
        download helpers exactly once; all resulting attachments carry the
        document's dataset/document metadata.
        """
        document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
        images = [
            "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview",
            "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview",
            "/files/bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb/file-preview",
            "/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png",
            "https://example.com/remote.png?x=1",
        ]
        upload_a = SimpleNamespace(id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", name="a.png")
        upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png")
        upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png")
        upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png")
        # Emulate db.session.query(...).where(...).all() -> all four upload rows.
        db_query = Mock()
        db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
        db_session = Mock()
        db_session.query.return_value = db_query

        with (
            patch.object(processor, "_extract_markdown_images", return_value=images),
            patch.object(processor, "_download_tool_file", return_value="tool-upload-id") as mock_tool_download,
            patch.object(processor, "_download_image", return_value="remote-upload-id") as mock_image_download,
            patch("core.rag.index_processor.index_processor_base.db.session", db_session),
        ):
            files = processor._get_content_files(document, current_user=Mock())

        # Five references in -> five attachments out (duplicate preserved).
        assert len(files) == 5
        assert all(isinstance(file, AttachmentDocument) for file in files)
        assert files[0].metadata["doc_type"] == DocType.IMAGE
        assert files[0].metadata["document_id"] == "doc-1"
        assert files[0].metadata["dataset_id"] == "ds-1"
        assert files[0].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
        assert files[1].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
        mock_tool_download.assert_called_once()
        mock_image_download.assert_called_once()
|
||||
|
||||
def test_get_content_files_skips_tool_and_remote_download_without_user(
|
||||
self, processor: _ForwardingBaseIndexProcessor
|
||||
) -> None:
|
||||
document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
|
||||
images = ["/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", "https://example.com/remote.png"]
|
||||
|
||||
with patch.object(processor, "_extract_markdown_images", return_value=images):
|
||||
files = processor._get_content_files(document, current_user=None)
|
||||
|
||||
assert files == []
|
||||
|
||||
def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
|
||||
images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"]
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.all.return_value = []
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
|
||||
with (
|
||||
patch.object(processor, "_extract_markdown_images", return_value=images),
|
||||
patch("core.rag.index_processor.index_processor_base.db.session", db_session),
|
||||
):
|
||||
files = processor._get_content_files(document)
|
||||
|
||||
assert files == []
|
||||
|
||||
def test_download_image_success_with_filename_from_content_disposition(
|
||||
self, processor: _ForwardingBaseIndexProcessor
|
||||
) -> None:
|
||||
response = Mock()
|
||||
response.headers = {
|
||||
"Content-Length": "4",
|
||||
"content-disposition": "attachment; filename=test-image.png",
|
||||
"content-type": "image/png",
|
||||
}
|
||||
response.raise_for_status.return_value = None
|
||||
response.iter_bytes.return_value = [b"data"]
|
||||
upload_result = SimpleNamespace(id="upload-id")
|
||||
|
||||
mock_db = Mock()
|
||||
mock_db.engine = Mock()
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response),
|
||||
patch("core.rag.index_processor.index_processor_base.db", mock_db),
|
||||
patch("services.file_service.FileService") as mock_file_service,
|
||||
):
|
||||
mock_file_service.return_value.upload_file.return_value = upload_result
|
||||
upload_id = processor._download_image("https://example.com/test.png", current_user=Mock())
|
||||
|
||||
assert upload_id == "upload-id"
|
||||
mock_file_service.return_value.upload_file.assert_called_once()
|
||||
|
||||
def test_download_image_validates_size_and_empty_content(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
too_large = Mock()
|
||||
too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"}
|
||||
too_large.raise_for_status.return_value = None
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large):
|
||||
assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None
|
||||
|
||||
empty = Mock()
|
||||
empty.headers = {"Content-Length": "0", "content-type": "image/png"}
|
||||
empty.raise_for_status.return_value = None
|
||||
empty.iter_bytes.return_value = []
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty):
|
||||
assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None
|
||||
|
||||
def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
response = Mock()
|
||||
response.headers = {"content-type": "image/png"}
|
||||
response.raise_for_status.return_value = None
|
||||
response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)]
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response):
|
||||
assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None
|
||||
|
||||
def test_download_image_handles_timeout_request_and_unexpected_errors(
|
||||
self, processor: _ForwardingBaseIndexProcessor
|
||||
) -> None:
|
||||
request = httpx.Request("GET", "https://example.com/image.png")
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
|
||||
side_effect=httpx.TimeoutException("timeout"),
|
||||
):
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
|
||||
side_effect=httpx.RequestError("bad request", request=request),
|
||||
):
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
|
||||
side_effect=RuntimeError("unexpected"),
|
||||
):
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = None
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.db.session", db_session):
|
||||
assert processor._download_tool_file("tool-id", current_user=Mock()) is None
|
||||
|
||||
def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png")
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = tool_file
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
mock_db = Mock()
|
||||
mock_db.session = db_session
|
||||
mock_db.engine = Mock()
|
||||
upload_result = SimpleNamespace(id="upload-id")
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.index_processor_base.db", mock_db),
|
||||
patch("core.rag.index_processor.index_processor_base.storage.load_once", return_value=b"blob") as mock_load,
|
||||
patch("services.file_service.FileService") as mock_file_service,
|
||||
):
|
||||
mock_file_service.return_value.upload_file.return_value = upload_result
|
||||
result = processor._download_tool_file("tool-id", current_user=Mock())
|
||||
|
||||
assert result == "upload-id"
|
||||
mock_load.assert_called_once_with("k1")
|
||||
mock_file_service.return_value.upload_file.assert_called_once()
|
||||
@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
|
||||
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
|
||||
|
||||
|
||||
class TestIndexProcessorFactory:
|
||||
def test_requires_index_type(self) -> None:
|
||||
factory = IndexProcessorFactory(index_type=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Index type must be specified"):
|
||||
factory.init_index_processor()
|
||||
|
||||
def test_builds_paragraph_processor(self) -> None:
|
||||
factory = IndexProcessorFactory(index_type=IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
processor = factory.init_index_processor()
|
||||
|
||||
assert isinstance(processor, ParagraphIndexProcessor)
|
||||
|
||||
def test_builds_qa_processor(self) -> None:
|
||||
factory = IndexProcessorFactory(index_type=IndexStructureType.QA_INDEX)
|
||||
|
||||
processor = factory.init_index_processor()
|
||||
|
||||
assert isinstance(processor, QAIndexProcessor)
|
||||
|
||||
def test_builds_parent_child_processor(self) -> None:
|
||||
factory = IndexProcessorFactory(index_type=IndexStructureType.PARENT_CHILD_INDEX)
|
||||
|
||||
processor = factory.init_index_processor()
|
||||
|
||||
assert isinstance(processor, ParentChildIndexProcessor)
|
||||
|
||||
def test_rejects_unsupported_index_type(self) -> None:
|
||||
factory = IndexProcessorFactory(index_type="unsupported")
|
||||
|
||||
with pytest.raises(ValueError, match="is not supported"):
|
||||
factory.init_index_processor()
|
||||
@ -12,13 +12,18 @@ All tests use mocking to avoid external dependencies and ensure fast, reliable e
|
||||
Tests follow the Arrange-Act-Assert pattern for clarity.
|
||||
"""
|
||||
|
||||
from operator import itemgetter
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
from core.rag.rerank.rerank_factory import RerankRunnerFactory
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
@ -26,7 +31,7 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
|
||||
|
||||
def create_mock_model_instance():
|
||||
def create_mock_model_instance() -> ModelInstance:
|
||||
"""Create a properly configured mock ModelInstance for reranking tests."""
|
||||
mock_instance = Mock(spec=ModelInstance)
|
||||
# Setup provider_model_bundle chain for check_model_support_vision
|
||||
@ -59,14 +64,7 @@ class TestRerankModelRunner:
|
||||
@pytest.fixture
|
||||
def mock_model_instance(self):
|
||||
"""Create a mock ModelInstance for reranking."""
|
||||
mock_instance = Mock(spec=ModelInstance)
|
||||
# Setup provider_model_bundle chain for check_model_support_vision
|
||||
mock_instance.provider_model_bundle = Mock()
|
||||
mock_instance.provider_model_bundle.configuration = Mock()
|
||||
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||
mock_instance.provider = "test-provider"
|
||||
mock_instance.model_name = "test-model"
|
||||
return mock_instance
|
||||
return create_mock_model_instance()
|
||||
|
||||
@pytest.fixture
|
||||
def rerank_runner(self, mock_model_instance):
|
||||
@ -382,6 +380,206 @@ class TestRerankModelRunner:
|
||||
assert call_kwargs["user"] == "user123"
|
||||
|
||||
|
||||
class _ForwardingBaseRerankRunner(BaseRerankRunner):
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
) -> list[Document]:
|
||||
return super().run(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user,
|
||||
query_type=query_type,
|
||||
)
|
||||
|
||||
|
||||
class TestBaseRerankRunner:
|
||||
def test_run_raises_not_implemented(self):
|
||||
runner = _ForwardingBaseRerankRunner()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner.run(query="python", documents=[])
|
||||
|
||||
|
||||
class TestRerankModelRunnerMultimodal:
|
||||
@pytest.fixture
|
||||
def mock_model_instance(self):
|
||||
return create_mock_model_instance()
|
||||
|
||||
@pytest.fixture
|
||||
def rerank_runner(self, mock_model_instance):
|
||||
return RerankModelRunner(rerank_model_instance=mock_model_instance)
|
||||
|
||||
def test_run_returns_original_documents_for_non_text_query_without_vision_support(
|
||||
self, rerank_runner, mock_model_instance
|
||||
):
|
||||
documents = [
|
||||
Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"),
|
||||
]
|
||||
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY)
|
||||
|
||||
assert result == documents
|
||||
mock_model_instance.invoke_rerank.assert_not_called()
|
||||
|
||||
def test_run_uses_multimodal_path_when_vision_support_is_enabled(self, rerank_runner):
|
||||
documents = [
|
||||
Document(page_content="doc", metadata={"doc_id": "doc1", "source": "wiki"}, provider="dify"),
|
||||
]
|
||||
rerank_result = RerankResult(
|
||||
model="rerank-model",
|
||||
docs=[RerankDocument(index=0, text="doc", score=0.88)],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm,
|
||||
patch.object(
|
||||
rerank_runner,
|
||||
"fetch_multimodal_rerank",
|
||||
return_value=(rerank_result, documents),
|
||||
) as mock_multimodal,
|
||||
):
|
||||
mock_mm.return_value.check_model_support_vision.return_value = True
|
||||
result = rerank_runner.run(query="python", documents=documents, query_type=QueryType.TEXT_QUERY)
|
||||
|
||||
mock_multimodal.assert_called_once()
|
||||
assert len(result) == 1
|
||||
assert result[0].metadata["score"] == 0.88
|
||||
|
||||
def test_fetch_multimodal_rerank_builds_docs_and_calls_text_rerank(self, rerank_runner):
|
||||
image_doc = Document(
|
||||
page_content="image-content",
|
||||
metadata={"doc_id": "img-1", "doc_type": DocType.IMAGE},
|
||||
provider="dify",
|
||||
)
|
||||
text_doc = Document(
|
||||
page_content="text-content",
|
||||
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
|
||||
provider="dify",
|
||||
)
|
||||
external_doc = Document(
|
||||
page_content="external-content",
|
||||
metadata={},
|
||||
provider="external",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.first.return_value = SimpleNamespace(key="image-key")
|
||||
rerank_result = RerankResult(model="rerank-model", docs=[])
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
|
||||
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once,
|
||||
patch.object(
|
||||
rerank_runner,
|
||||
"fetch_text_rerank",
|
||||
return_value=(rerank_result, [image_doc, text_doc, external_doc]),
|
||||
) as mock_text_rerank,
|
||||
):
|
||||
result, unique_documents = rerank_runner.fetch_multimodal_rerank(
|
||||
query="python",
|
||||
documents=[image_doc, text_doc, external_doc, external_doc],
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
|
||||
assert result == rerank_result
|
||||
assert len(unique_documents) == 3
|
||||
mock_load_once.assert_called_once_with("image-key")
|
||||
text_rerank_call_args = mock_text_rerank.call_args.args
|
||||
assert len(text_rerank_call_args[1]) == 3
|
||||
|
||||
def test_fetch_multimodal_rerank_skips_missing_image_upload(self, rerank_runner):
|
||||
image_doc = Document(
|
||||
page_content="image-content",
|
||||
metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE},
|
||||
provider="dify",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
rerank_result = RerankResult(model="rerank-model", docs=[])
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
|
||||
patch.object(
|
||||
rerank_runner,
|
||||
"fetch_text_rerank",
|
||||
return_value=(rerank_result, [image_doc]),
|
||||
) as mock_text_rerank,
|
||||
):
|
||||
result, unique_documents = rerank_runner.fetch_multimodal_rerank(
|
||||
query="python",
|
||||
documents=[image_doc],
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
|
||||
assert result == rerank_result
|
||||
assert unique_documents == [image_doc]
|
||||
docs_arg = mock_text_rerank.call_args.args[1]
|
||||
assert len(docs_arg) == 1
|
||||
|
||||
def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance):
|
||||
text_doc = Document(
|
||||
page_content="text-content",
|
||||
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
|
||||
provider="dify",
|
||||
)
|
||||
query_chain = Mock()
|
||||
query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key")
|
||||
rerank_result = RerankResult(
|
||||
model="rerank-model",
|
||||
docs=[RerankDocument(index=0, text="text-content", score=0.77)],
|
||||
)
|
||||
mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain),
|
||||
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
|
||||
):
|
||||
result, unique_documents = rerank_runner.fetch_multimodal_rerank(
|
||||
query="query-upload-id",
|
||||
documents=[text_doc],
|
||||
score_threshold=0.2,
|
||||
top_n=2,
|
||||
user="user-1",
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
)
|
||||
|
||||
assert result == rerank_result
|
||||
assert unique_documents == [text_doc]
|
||||
invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs
|
||||
assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE
|
||||
assert invoke_kwargs["docs"][0]["content"] == "text-content"
|
||||
assert invoke_kwargs["user"] == "user-1"
|
||||
|
||||
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
|
||||
query_chain = Mock()
|
||||
query_chain.where.return_value.first.return_value = None
|
||||
|
||||
with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain):
|
||||
with pytest.raises(ValueError, match="Upload file not found for query"):
|
||||
rerank_runner.fetch_multimodal_rerank(
|
||||
query="missing-upload-id",
|
||||
documents=[],
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
)
|
||||
|
||||
def test_fetch_multimodal_rerank_rejects_unsupported_query_type(self, rerank_runner):
|
||||
with pytest.raises(ValueError, match="is not supported"):
|
||||
rerank_runner.fetch_multimodal_rerank(
|
||||
query="python",
|
||||
documents=[],
|
||||
query_type="unsupported_query_type",
|
||||
)
|
||||
|
||||
|
||||
class TestWeightRerankRunner:
|
||||
"""Unit tests for WeightRerankRunner.
|
||||
|
||||
@ -512,34 +710,39 @@ class TestWeightRerankRunner:
|
||||
- TF-IDF scores are calculated correctly
|
||||
- Cosine similarity is computed for keyword vectors
|
||||
"""
|
||||
# Arrange: Create runner
|
||||
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
|
||||
|
||||
# Mock keyword extraction with specific keywords
|
||||
keyword_map = {
|
||||
"python programming": ["python", "programming"],
|
||||
"Python is a programming language": ["python", "programming", "language"],
|
||||
"JavaScript for web development": ["javascript", "web"],
|
||||
"Java object-oriented programming": ["java", "programming"],
|
||||
}
|
||||
mock_handler_instance = MagicMock()
|
||||
mock_handler_instance.extract_keywords.side_effect = [
|
||||
["python", "programming"], # query
|
||||
["python", "programming", "language"], # doc1
|
||||
["javascript", "web"], # doc2
|
||||
["java", "programming"], # doc3
|
||||
]
|
||||
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
|
||||
mock_jieba_handler.return_value = mock_handler_instance
|
||||
|
||||
# Mock embedding
|
||||
mock_embedding_instance = MagicMock()
|
||||
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4]
|
||||
mock_cache_embedding.return_value = mock_cache_instance
|
||||
|
||||
# Act: Run reranking
|
||||
query_scores = runner._calculate_keyword_score("python programming", sample_documents_with_vectors)
|
||||
vector_scores = runner._calculate_cosine(
|
||||
"tenant123", "python programming", sample_documents_with_vectors, weights_config.vector_setting
|
||||
)
|
||||
expected_scores = {
|
||||
doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score)
|
||||
for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores)
|
||||
}
|
||||
|
||||
result = runner.run(query="python programming", documents=sample_documents_with_vectors)
|
||||
|
||||
# Assert: Keywords are extracted and scores are calculated
|
||||
assert len(result) == 3
|
||||
# Document 1 should have highest keyword score (matches both query terms)
|
||||
# Document 3 should have medium score (matches one term)
|
||||
# Document 2 should have lowest score (matches no terms)
|
||||
expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)]
|
||||
assert [doc.metadata["doc_id"] for doc in result] == expected_order
|
||||
for doc in result:
|
||||
doc_id = doc.metadata["doc_id"]
|
||||
assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6)
|
||||
|
||||
def test_vector_score_calculation(
|
||||
self,
|
||||
@ -556,30 +759,42 @@ class TestWeightRerankRunner:
|
||||
- Cosine similarity is calculated with document vectors
|
||||
- Vector scores are properly normalized
|
||||
"""
|
||||
# Arrange: Create runner
|
||||
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
|
||||
|
||||
# Mock keyword extraction
|
||||
keyword_map = {
|
||||
"test query": ["test"],
|
||||
"Python is a programming language": ["python"],
|
||||
"JavaScript for web development": ["javascript"],
|
||||
"Java object-oriented programming": ["java"],
|
||||
}
|
||||
mock_handler_instance = MagicMock()
|
||||
mock_handler_instance.extract_keywords.return_value = ["test"]
|
||||
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
|
||||
mock_jieba_handler.return_value = mock_handler_instance
|
||||
|
||||
# Mock embedding model
|
||||
mock_embedding_instance = MagicMock()
|
||||
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
|
||||
|
||||
# Mock cache embedding with specific query vector
|
||||
mock_cache_instance = MagicMock()
|
||||
query_vector = [0.2, 0.3, 0.4, 0.5]
|
||||
mock_cache_instance.embed_query.return_value = query_vector
|
||||
mock_cache_embedding.return_value = mock_cache_instance
|
||||
|
||||
# Act: Run reranking
|
||||
query_scores = runner._calculate_keyword_score("test query", sample_documents_with_vectors)
|
||||
vector_scores = runner._calculate_cosine(
|
||||
"tenant123", "test query", sample_documents_with_vectors, weights_config.vector_setting
|
||||
)
|
||||
expected_scores = {
|
||||
doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score)
|
||||
for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores)
|
||||
}
|
||||
|
||||
result = runner.run(query="test query", documents=sample_documents_with_vectors)
|
||||
|
||||
# Assert: Vector scores are calculated
|
||||
assert len(result) == 3
|
||||
# Verify cosine similarity was computed (doc2 vector is closest to query vector)
|
||||
expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)]
|
||||
assert [doc.metadata["doc_id"] for doc in result] == expected_order
|
||||
for doc in result:
|
||||
doc_id = doc.metadata["doc_id"]
|
||||
assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6)
|
||||
|
||||
def test_score_threshold_filtering_weighted(
|
||||
self,
|
||||
@ -742,28 +957,40 @@ class TestWeightRerankRunner:
|
||||
- Keyword weight (0.4) is applied to keyword scores
|
||||
- Combined score is the sum of weighted components
|
||||
"""
|
||||
# Arrange: Create runner with known weights
|
||||
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
|
||||
|
||||
# Mock keyword extraction
|
||||
keyword_map = {
|
||||
"test": ["test"],
|
||||
"Python is a programming language": ["python", "language"],
|
||||
"JavaScript for web development": ["javascript", "web"],
|
||||
"Java object-oriented programming": ["java", "programming"],
|
||||
}
|
||||
mock_handler_instance = MagicMock()
|
||||
mock_handler_instance.extract_keywords.return_value = ["test"]
|
||||
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
|
||||
mock_jieba_handler.return_value = mock_handler_instance
|
||||
|
||||
# Mock embedding
|
||||
mock_embedding_instance = MagicMock()
|
||||
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4]
|
||||
mock_cache_embedding.return_value = mock_cache_instance
|
||||
|
||||
# Act: Run reranking
|
||||
query_scores = runner._calculate_keyword_score("test", sample_documents_with_vectors)
|
||||
vector_scores = runner._calculate_cosine(
|
||||
"tenant123", "test", sample_documents_with_vectors, weights_config.vector_setting
|
||||
)
|
||||
expected_scores = {
|
||||
doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score)
|
||||
for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores)
|
||||
}
|
||||
|
||||
result = runner.run(query="test", documents=sample_documents_with_vectors)
|
||||
|
||||
# Assert: Scores are combined with weights
|
||||
# Score = 0.6 * vector_score + 0.4 * keyword_score
|
||||
assert len(result) == 3
|
||||
assert all("score" in doc.metadata for doc in result)
|
||||
expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)]
|
||||
assert [doc.metadata["doc_id"] for doc in result] == expected_order
|
||||
for doc in result:
|
||||
doc_id = doc.metadata["doc_id"]
|
||||
assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6)
|
||||
|
||||
def test_existing_vector_score_in_metadata(
|
||||
self,
|
||||
@ -778,7 +1005,6 @@ class TestWeightRerankRunner:
|
||||
- If document already has a score in metadata, it's used
|
||||
- Cosine similarity calculation is skipped for such documents
|
||||
"""
|
||||
# Arrange: Documents with pre-existing scores
|
||||
documents = [
|
||||
Document(
|
||||
page_content="Content with existing score",
|
||||
@ -790,24 +1016,29 @@ class TestWeightRerankRunner:
|
||||
|
||||
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config)
|
||||
|
||||
# Mock keyword extraction
|
||||
keyword_map = {
|
||||
"test": ["test"],
|
||||
"Content with existing score": ["test"],
|
||||
}
|
||||
mock_handler_instance = MagicMock()
|
||||
mock_handler_instance.extract_keywords.return_value = ["test"]
|
||||
mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text]
|
||||
mock_jieba_handler.return_value = mock_handler_instance
|
||||
|
||||
# Mock embedding
|
||||
mock_embedding_instance = MagicMock()
|
||||
mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance
|
||||
mock_cache_instance = MagicMock()
|
||||
mock_cache_instance.embed_query.return_value = [0.1, 0.2]
|
||||
mock_cache_embedding.return_value = mock_cache_instance
|
||||
|
||||
# Act: Run reranking
|
||||
query_scores = runner._calculate_keyword_score("test", documents)
|
||||
vector_scores = runner._calculate_cosine("tenant123", "test", documents, weights_config.vector_setting)
|
||||
expected_score = 0.6 * vector_scores[0] + 0.4 * query_scores[0]
|
||||
|
||||
result = runner.run(query="test", documents=documents)
|
||||
|
||||
# Assert: Existing score is used in calculation
|
||||
assert len(result) == 1
|
||||
# The final score should incorporate the existing score (0.95) with vector weight (0.6)
|
||||
assert result[0].metadata["doc_id"] == "doc1"
|
||||
assert result[0].metadata["score"] == pytest.approx(expected_score, rel=1e-6)
|
||||
|
||||
|
||||
class TestRerankRunnerFactory:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,873 +0,0 @@
|
||||
"""
|
||||
Unit tests for DatasetRetrieval.process_metadata_filter_func.
|
||||
|
||||
This module provides comprehensive test coverage for the process_metadata_filter_func
|
||||
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
|
||||
filter expressions based on metadata filtering conditions.
|
||||
|
||||
Conditions Tested:
|
||||
==================
|
||||
1. **String Conditions**: contains, not contains, start with, end with
|
||||
2. **Equality Conditions**: is / =, is not / ≠
|
||||
3. **Null Conditions**: empty, not empty
|
||||
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
|
||||
5. **List Conditions**: in
|
||||
6. **Edge Cases**: None values, different data types (str, int, float)
|
||||
|
||||
Test Architecture:
|
||||
==================
|
||||
- Direct instantiation of DatasetRetrieval
|
||||
- Mocking of DatasetDocument model attributes
|
||||
- Verification of SQLAlchemy filter expressions
|
||||
- Follows Arrange-Act-Assert (AAA) pattern
|
||||
|
||||
Running Tests:
|
||||
==============
|
||||
# Run all tests in this module
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
|
||||
|
||||
# Run a specific test
|
||||
uv run --project api pytest \
|
||||
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
|
||||
TestProcessMetadataFilterFunc::test_contains_condition -v
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
|
||||
class TestProcessMetadataFilterFunc:
|
||||
"""
|
||||
Comprehensive test suite for process_metadata_filter_func method.
|
||||
|
||||
This test class validates all metadata filtering conditions supported by
|
||||
the DatasetRetrieval class, including string operations, numeric comparisons,
|
||||
null checks, and list operations.
|
||||
|
||||
Method Signature:
|
||||
==================
|
||||
def process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
) -> list:
|
||||
|
||||
The method builds SQLAlchemy filter expressions by:
|
||||
1. Validating value is not None (except for empty/not empty conditions)
|
||||
2. Using DatasetDocument.doc_metadata JSON field operations
|
||||
3. Adding appropriate SQLAlchemy expressions to the filters list
|
||||
4. Returning the updated filters list
|
||||
|
||||
Mocking Strategy:
|
||||
==================
|
||||
- Mock DatasetDocument.doc_metadata to avoid database dependencies
|
||||
- Verify filter expressions are created correctly
|
||||
- Test with various data types (str, int, float, list)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def retrieval(self):
|
||||
"""
|
||||
Create a DatasetRetrieval instance for testing.
|
||||
|
||||
Returns:
|
||||
DatasetRetrieval: Instance to test process_metadata_filter_func
|
||||
"""
|
||||
return DatasetRetrieval()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_metadata(self):
|
||||
"""
|
||||
Mock the DatasetDocument.doc_metadata JSON field.
|
||||
|
||||
The method uses DatasetDocument.doc_metadata[metadata_name] to access
|
||||
JSON fields. We mock this to avoid database dependencies.
|
||||
|
||||
Returns:
|
||||
Mock: Mocked doc_metadata attribute
|
||||
"""
|
||||
mock_metadata_field = MagicMock()
|
||||
|
||||
# Create mock for string access
|
||||
mock_string_access = MagicMock()
|
||||
mock_string_access.like = MagicMock()
|
||||
mock_string_access.notlike = MagicMock()
|
||||
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.in_ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for float access (for numeric comparisons)
|
||||
mock_float_access = MagicMock()
|
||||
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for null checks
|
||||
mock_null_access = MagicMock()
|
||||
mock_null_access.is_ = MagicMock(return_value=MagicMock())
|
||||
mock_null_access.isnot = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Setup __getitem__ to return appropriate mock based on usage
|
||||
def getitem_side_effect(name):
|
||||
if name in ["author", "title", "category"]:
|
||||
return mock_string_access
|
||||
elif name in ["year", "price", "rating"]:
|
||||
return mock_float_access
|
||||
else:
|
||||
return mock_string_access
|
||||
|
||||
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
|
||||
mock_metadata_field.as_string.return_value = mock_string_access
|
||||
mock_metadata_field.as_float.return_value = mock_float_access
|
||||
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
|
||||
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
|
||||
|
||||
return mock_metadata_field
|
||||
|
||||
# ==================== String Condition Tests ====================
|
||||
|
||||
def test_contains_condition_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'contains' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "John"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_contains_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with NOT LIKE expression
|
||||
- Pattern matching uses %value% syntax with negation
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not contains"
|
||||
metadata_name = "title"
|
||||
value = "banned"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_start_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'start with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "start with"
|
||||
metadata_name = "category"
|
||||
value = "tech"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_end_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'end with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "end with"
|
||||
metadata_name = "filename"
|
||||
value = ".pdf"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Equality Condition Tests ====================
|
||||
|
||||
def test_is_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' (=) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with equality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = "Jane Doe"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_equals_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test '=' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is' condition
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "="
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_int_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with integer value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_float_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with float value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "price"
|
||||
value = 19.99
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' (≠) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with inequality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "author"
|
||||
value = "Unknown"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test '≠' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is not' condition
|
||||
- Inequality expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≠"
|
||||
metadata_name = "category"
|
||||
value = "archived"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_numeric_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' condition with numeric value.
|
||||
|
||||
Verifies:
|
||||
- Numeric inequality comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Null Condition Tests ====================
|
||||
|
||||
def test_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'empty' condition (null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "empty"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not empty' condition (not null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NOT NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not empty"
|
||||
metadata_name = "description"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Numeric Comparison Tests ====================
|
||||
|
||||
def test_before_condition(self, retrieval):
|
||||
"""
|
||||
Test 'before' (<) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "before"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '<' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'before' condition
|
||||
- Less than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "price"
|
||||
value = 100.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_after_condition(self, retrieval):
|
||||
"""
|
||||
Test 'after' (>) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "after"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '>' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'after' condition
|
||||
- Greater than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≤' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≤"
|
||||
metadata_name = "price"
|
||||
value = 50.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '<=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≤' condition
|
||||
- Less than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<="
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≥' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≥"
|
||||
metadata_name = "rating"
|
||||
value = 3.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '>=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≥' condition
|
||||
- Greater than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== List/In Condition Tests ====================
|
||||
|
||||
def test_in_condition_with_comma_separated_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with comma-separated string value.
|
||||
|
||||
Verifies:
|
||||
- String is split into list
|
||||
- Whitespace is trimmed from each value
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "tech, science, AI "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_list_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with list value.
|
||||
|
||||
Verifies:
|
||||
- List is processed correctly
|
||||
- None values are filtered out
|
||||
- IN expression is created with valid values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "tags"
|
||||
value = ["python", "javascript", None, "golang"]
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_tuple_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with tuple value.
|
||||
|
||||
Verifies:
|
||||
- Tuple is processed like a list
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ("tech", "science", "ai")
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_empty_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with empty string value.
|
||||
|
||||
Verifies:
|
||||
- Empty string results in literal(False) filter
|
||||
- No valid values to match
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
# Verify it's a literal(False) expression
|
||||
# This is a bit tricky to test without access to the actual expression
|
||||
|
||||
def test_in_condition_with_only_whitespace(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with whitespace-only string value.
|
||||
|
||||
Verifies:
|
||||
- Whitespace-only string results in literal(False) filter
|
||||
- All values are stripped and filtered out
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = " , , "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_single_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with single non-comma string.
|
||||
|
||||
Verifies:
|
||||
- Single string is treated as single-item list
|
||||
- IN expression is created with one value
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Edge Case Tests ====================
|
||||
|
||||
def test_none_value_with_non_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with conditions that require value.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values (except empty/not empty)
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0 # No filter added
|
||||
|
||||
def test_none_value_with_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with 'is' (=) condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_none_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "year"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_existing_filters_preserved(self, retrieval):
|
||||
"""
|
||||
Test that existing filters are preserved.
|
||||
|
||||
Verifies:
|
||||
- Existing filters in the list are not removed
|
||||
- New filters are appended to the list
|
||||
"""
|
||||
existing_filter = MagicMock()
|
||||
filters = [existing_filter]
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 2
|
||||
assert filters[0] == existing_filter
|
||||
|
||||
def test_multiple_filters_accumulated(self, retrieval):
|
||||
"""
|
||||
Test multiple calls to accumulate filters.
|
||||
|
||||
Verifies:
|
||||
- Each call adds a new filter to the list
|
||||
- All filters are preserved across calls
|
||||
"""
|
||||
filters = []
|
||||
|
||||
# First filter
|
||||
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
|
||||
assert len(filters) == 1
|
||||
|
||||
# Second filter
|
||||
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
|
||||
assert len(filters) == 2
|
||||
|
||||
# Third filter
|
||||
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
|
||||
assert len(filters) == 3
|
||||
|
||||
def test_unknown_condition(self, retrieval):
|
||||
"""
|
||||
Test unknown/unsupported condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for unknown conditions
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "unknown_condition"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_empty_string_value_with_contains(self, retrieval):
|
||||
"""
|
||||
Test empty string value with 'contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filter is added even with empty string
|
||||
- LIKE expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_special_characters_in_value(self, retrieval):
|
||||
"""
|
||||
Test special characters in value string.
|
||||
|
||||
Verifies:
|
||||
- Special characters are handled in value
|
||||
- LIKE expression is created correctly
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "title"
|
||||
value = "C++ & Python's features"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_zero_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test zero value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Zero is treated as valid value
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "price"
|
||||
value = 0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_negative_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test negative value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Negative numbers are handled correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "temperature"
|
||||
value = -10.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_float_value_with_integer_comparison(self, retrieval):
|
||||
"""
|
||||
Test float value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Float values work correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
@ -1,113 +0,0 @@
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class TestRetrievalService:
    """Regression tests around DatasetRetrieval._multiple_retrieve_thread and
    Flask app-context handling during reranking."""

    @pytest.fixture
    def mock_dataset(self) -> Dataset:
        """Build a mock high-quality Dataset with random ids."""
        dataset = Mock(spec=Dataset)
        dataset.id = str(uuid4())
        dataset.tenant_id = str(uuid4())
        dataset.name = "test_dataset"
        dataset.indexing_technique = "high_quality"
        dataset.provider = "dify"
        return dataset

    def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
        """
        Repro test for current bug:
        reranking runs after `with flask_app.app_context():` exits.
        `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`,
        so we must assert from that list (not from an outer try/except).
        """
        dataset_retrieval = DatasetRetrieval()
        flask_app = Flask(__name__)
        tenant_id = str(uuid4())

        # second dataset to ensure dataset_count > 1 reranking branch
        secondary_dataset = Mock(spec=Dataset)
        secondary_dataset.id = str(uuid4())
        secondary_dataset.provider = "dify"
        secondary_dataset.indexing_technique = "high_quality"

        # retriever returns 1 doc into internal list (all_documents_item)
        document = Document(
            page_content="Context aware doc",
            metadata={
                "doc_id": "doc1",
                "score": 0.95,
                "document_id": str(uuid4()),
                "dataset_id": mock_dataset.id,
            },
            provider="dify",
        )

        # Stand-in for DatasetRetrieval._retriever: ignores its query arguments
        # and simply deposits the prepared document into the shared list.
        def fake_retriever(
            flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
        ):
            all_documents.append(document)

        # Counters proving the post-processor was constructed/invoked.
        called = {"init": 0, "invoke": 0}

        # Replacement for DataPostProcessor that touches `current_app` so it
        # fails loudly when no Flask app context is active.
        class ContextRequiredPostProcessor:
            def __init__(self, *args, **kwargs):
                called["init"] += 1
                # will raise RuntimeError if no Flask app context exists
                _ = current_app.name

            def invoke(self, *args, **kwargs):
                called["invoke"] += 1
                _ = current_app.name
                return kwargs.get("documents") or args[1]

        # output list from _multiple_retrieve_thread
        all_documents: list[Document] = []

        # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here
        thread_exceptions: list[Exception] = []

        # Run the retrieval in a real thread, mirroring production usage.
        def target():
            with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
                with patch(
                    "core.rag.retrieval.dataset_retrieval.DataPostProcessor",
                    ContextRequiredPostProcessor,
                ):
                    dataset_retrieval._multiple_retrieve_thread(
                        flask_app=flask_app,
                        available_datasets=[mock_dataset, secondary_dataset],
                        metadata_condition=None,
                        metadata_filter_document_ids=None,
                        all_documents=all_documents,
                        tenant_id=tenant_id,
                        reranking_enable=True,
                        reranking_mode="reranking_model",
                        reranking_model={
                            "reranking_provider_name": "cohere",
                            "reranking_model_name": "rerank-v2",
                        },
                        weights=None,
                        top_k=3,
                        score_threshold=0.0,
                        query="test query",
                        attachment_id=None,
                        dataset_count=2,  # force reranking branch
                        thread_exceptions=thread_exceptions,  # ✅ key
                    )

        t = threading.Thread(target=target)
        t.start()
        t.join()

        # Ensure reranking branch was actually executed
        assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."

        # Current buggy code should record an exception (not raise it)
        # NOTE(review): the comment above suggests an exception is expected, yet the
        # assertion below requires the list to be empty — confirm the intended
        # post-fix behavior matches this assertion.
        assert not thread_exceptions, thread_exceptions
|
||||
@ -0,0 +1,100 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class TestFunctionCallMultiDatasetRouter:
    """Unit tests for FunctionCallMultiDatasetRouter.invoke."""

    @staticmethod
    def _tool(name):
        # Build a mock dataset tool; .name must be set after construction,
        # since passing name= to Mock() configures the mock itself.
        stub = Mock()
        stub.name = name
        return stub

    def test_invoke_returns_none_when_no_tools(self) -> None:
        """With no tools the router selects nothing and reports empty usage."""
        selected, spent = FunctionCallMultiDatasetRouter().invoke(
            query="python",
            dataset_tools=[],
            model_config=Mock(),
            model_instance=Mock(),
        )

        assert selected is None
        assert spent == LLMUsage.empty_usage()

    def test_invoke_returns_single_tool_directly(self) -> None:
        """A single tool is returned without consulting the model."""
        selected, spent = FunctionCallMultiDatasetRouter().invoke(
            query="python",
            dataset_tools=[self._tool("dataset-1")],
            model_config=Mock(),
            model_instance=Mock(),
        )

        assert selected == "dataset-1"
        assert spent == LLMUsage.empty_usage()

    def test_invoke_returns_tool_from_model_response(self) -> None:
        """The dataset named by the model's tool call is selected."""
        expected_usage = LLMUsage.empty_usage()
        reply = Mock()
        reply.usage = expected_usage
        reply.message.tool_calls = [Mock(function=Mock())]
        reply.message.tool_calls[0].function.name = "dataset-2"
        llm = Mock()
        llm.invoke_llm.return_value = reply

        selected, spent = FunctionCallMultiDatasetRouter().invoke(
            query="python",
            dataset_tools=[self._tool("dataset-1"), self._tool("dataset-2")],
            model_config=Mock(),
            model_instance=llm,
        )

        assert selected == "dataset-2"
        assert spent == expected_usage
        llm.invoke_llm.assert_called_once()

    def test_invoke_returns_none_when_no_tool_calls(self) -> None:
        """A model response without tool calls yields no selection but keeps usage."""
        reply = Mock()
        reply.usage = LLMUsage.empty_usage()
        reply.message.tool_calls = []
        llm = Mock()
        llm.invoke_llm.return_value = reply

        selected, spent = FunctionCallMultiDatasetRouter().invoke(
            query="python",
            dataset_tools=[self._tool("dataset-1"), self._tool("dataset-2")],
            model_config=Mock(),
            model_instance=llm,
        )

        assert selected is None
        assert spent == reply.usage

    def test_invoke_returns_empty_usage_when_model_raises(self) -> None:
        """A model invocation failure degrades to (None, empty usage)."""
        llm = Mock()
        llm.invoke_llm.side_effect = RuntimeError("boom")

        selected, spent = FunctionCallMultiDatasetRouter().invoke(
            query="python",
            dataset_tools=[self._tool("dataset-1"), self._tool("dataset-2")],
            model_config=Mock(),
            model_instance=llm,
        )

        assert selected is None
        assert spent == LLMUsage.empty_usage()
|
||||
@ -0,0 +1,252 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
|
||||
|
||||
class TestReactMultiDatasetRouter:
    """Unit tests for ReactMultiDatasetRouter.

    Covers the public invoke() short-circuits and error handling, the internal
    _react_invoke() chat/completion branches, the streaming-result handler, and
    the prompt-construction helpers. All model calls are mocked; no real LLM is
    invoked.
    """

    def test_invoke_returns_none_when_no_tools(self) -> None:
        """With no dataset tools, invoke() short-circuits to (None, empty usage)."""
        router = ReactMultiDatasetRouter()

        dataset_id, usage = router.invoke(
            query="python",
            dataset_tools=[],
            model_config=Mock(),
            model_instance=Mock(),
            user_id="u1",
            tenant_id="t1",
        )

        assert dataset_id is None
        assert usage == LLMUsage.empty_usage()

    def test_invoke_returns_single_tool_directly(self) -> None:
        """A single tool is selected directly without consulting the model."""
        router = ReactMultiDatasetRouter()
        tool = Mock()
        tool.name = "dataset-1"

        dataset_id, usage = router.invoke(
            query="python",
            dataset_tools=[tool],
            model_config=Mock(),
            model_instance=Mock(),
            user_id="u1",
            tenant_id="t1",
        )

        assert dataset_id == "dataset-1"
        assert usage == LLMUsage.empty_usage()

    def test_invoke_returns_tool_from_react_invoke(self) -> None:
        """With multiple tools, invoke() delegates to _react_invoke and forwards its result."""
        router = ReactMultiDatasetRouter()
        usage = LLMUsage.empty_usage()
        # Mock(name=...) sets the mock's repr name, so .name must be assigned separately.
        tool_1 = Mock(name="dataset-1")
        tool_1.name = "dataset-1"
        tool_2 = Mock(name="dataset-2")
        tool_2.name = "dataset-2"

        with patch.object(router, "_react_invoke", return_value=("dataset-2", usage)) as mock_react:
            dataset_id, returned_usage = router.invoke(
                query="python",
                dataset_tools=[tool_1, tool_2],
                model_config=Mock(),
                model_instance=Mock(),
                user_id="u1",
                tenant_id="t1",
            )

        mock_react.assert_called_once()
        assert dataset_id == "dataset-2"
        assert returned_usage == usage

    def test_invoke_handles_react_invoke_errors(self) -> None:
        """Errors raised inside _react_invoke are swallowed; invoke() falls back to (None, empty usage)."""
        router = ReactMultiDatasetRouter()
        tool_1 = Mock()
        tool_1.name = "dataset-1"
        tool_2 = Mock()
        tool_2.name = "dataset-2"

        with patch.object(router, "_react_invoke", side_effect=RuntimeError("boom")):
            dataset_id, usage = router.invoke(
                query="python",
                dataset_tools=[tool_1, tool_2],
                model_config=Mock(),
                model_instance=Mock(),
                user_id="u1",
                tenant_id="t1",
            )

        assert dataset_id is None
        assert usage == LLMUsage.empty_usage()

    def test_react_invoke_returns_action_tool(self) -> None:
        """In chat mode, a parsed ReactAction yields the chosen tool name and the LLM usage."""
        router = ReactMultiDatasetRouter()
        model_config = Mock()
        model_config.mode = "chat"
        model_config.parameters = {"temperature": 0.1}
        usage = LLMUsage.empty_usage()
        tools = [Mock(name="dataset-1"), Mock(name="dataset-2")]
        tools[0].name = "dataset-1"
        tools[0].description = "desc"
        tools[1].name = "dataset-2"
        tools[1].description = "desc"

        # Patch the prompt builder, prompt transform, LLM call, and output parser so
        # only _react_invoke's own control flow is exercised.
        with (
            patch.object(router, "create_chat_prompt", return_value=[Mock()]) as mock_chat_prompt,
            patch(
                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
            ) as mock_prompt_transform,
            patch.object(router, "_invoke_llm", return_value=('{"action":"dataset-2","action_input":{}}', usage)),
            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
        ):
            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
            mock_parser_cls.return_value.parse.return_value = ReactAction("dataset-2", {}, "log")

            dataset_id, returned_usage = router._react_invoke(
                query="python",
                model_config=model_config,
                model_instance=Mock(),
                tools=tools,
                user_id="u1",
                tenant_id="t1",
            )

        # Chat mode must build a chat prompt, not a completion prompt.
        mock_chat_prompt.assert_called_once()
        assert dataset_id == "dataset-2"
        assert returned_usage == usage

    def test_react_invoke_returns_none_for_finish(self) -> None:
        """In completion mode, a ReactFinish ("Final Answer") yields no dataset id."""
        router = ReactMultiDatasetRouter()
        model_config = Mock()
        model_config.mode = "completion"
        model_config.parameters = {"temperature": 0.1}
        usage = LLMUsage.empty_usage()
        tool = Mock()
        tool.name = "dataset-1"
        tool.description = "desc"

        with (
            patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt,
            patch(
                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
            ) as mock_prompt_transform,
            patch.object(
                router, "_invoke_llm", return_value=('{"action":"Final Answer","action_input":"done"}', usage)
            ),
            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
        ):
            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
            mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log")

            dataset_id, returned_usage = router._react_invoke(
                query="python",
                model_config=model_config,
                model_instance=Mock(),
                tools=[tool],
                user_id="u1",
                tenant_id="t1",
            )

        # Completion mode must build a completion prompt.
        mock_completion_prompt.assert_called_once()
        assert dataset_id is None
        assert returned_usage == usage

    def test_invoke_llm_and_handle_result(self) -> None:
        """_invoke_llm streams the model output, aggregates text/usage, and deducts quota."""
        router = ReactMultiDatasetRouter()
        usage = LLMUsage.empty_usage()
        # Minimal stand-in for one streamed LLMResultChunk.
        delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=usage)
        chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta)
        model_instance = Mock()
        model_instance.invoke_llm.return_value = iter([chunk])

        with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct:
            text, returned_usage = router._invoke_llm(
                completion_param={"temperature": 0.1},
                model_instance=model_instance,
                prompt_messages=[Mock()],
                stop=["Observation:"],
                user_id="u1",
                tenant_id="t1",
            )

        assert text == "part"
        assert returned_usage == usage
        mock_deduct.assert_called_once()

    def test_handle_invoke_result_with_empty_usage(self) -> None:
        """Chunks without usage data fall back to LLMUsage.empty_usage()."""
        router = ReactMultiDatasetRouter()
        delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=None)
        chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta)

        text, usage = router._handle_invoke_result(iter([chunk]))

        assert text == "part"
        assert usage == LLMUsage.empty_usage()

    def test_create_chat_prompt(self) -> None:
        """create_chat_prompt emits a system+user pair with every tool listed in the system text."""
        router = ReactMultiDatasetRouter()
        tool_1 = Mock()
        tool_1.name = "dataset-1"
        tool_1.description = "d1"
        tool_2 = Mock()
        tool_2.name = "dataset-2"
        tool_2.description = "d2"

        chat_prompt = router.create_chat_prompt(query="python", tools=[tool_1, tool_2])
        assert len(chat_prompt) == 2
        assert chat_prompt[0].role == PromptMessageRole.SYSTEM
        assert chat_prompt[1].role == PromptMessageRole.USER
        assert "dataset-1" in chat_prompt[0].text
        assert "dataset-2" in chat_prompt[0].text

    def test_create_completion_prompt(self) -> None:
        """create_completion_prompt renders each tool as "name: description" in one prompt."""
        router = ReactMultiDatasetRouter()
        tool_1 = Mock()
        tool_1.name = "dataset-1"
        tool_1.description = "d1"
        tool_2 = Mock()
        tool_2.name = "dataset-2"
        tool_2.description = "d2"

        completion_prompt = router.create_completion_prompt(tools=[tool_1, tool_2])
        assert "dataset-1: d1" in completion_prompt.text
        assert "dataset-2: d2" in completion_prompt.text

    def test_react_invoke_uses_completion_branch_for_non_chat_mode(self) -> None:
        """Any mode other than "chat" takes the completion-prompt branch of _react_invoke."""
        router = ReactMultiDatasetRouter()
        model_config = Mock()
        model_config.mode = "unknown-mode"
        model_config.parameters = {}
        tool = Mock()
        tool.name = "dataset-1"
        tool.description = "desc"

        with (
            patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt,
            patch(
                "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform"
            ) as mock_prompt_transform,
            patch.object(
                router,
                "_invoke_llm",
                return_value=('{"action":"Final Answer","action_input":"done"}', LLMUsage.empty_usage()),
            ),
            patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls,
        ):
            mock_prompt_transform.return_value.get_prompt.return_value = [Mock()]
            mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log")
            dataset_id, usage = router._react_invoke(
                query="python",
                model_config=model_config,
                model_instance=Mock(),
                tools=[tool],
                user_id="u1",
                tenant_id="t1",
            )

        mock_completion_prompt.assert_called_once()
        assert dataset_id is None
        assert usage == LLMUsage.empty_usage()
|
||||
@ -0,0 +1,69 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
|
||||
|
||||
class TestStructuredChatOutputParser:
    """Behavioral tests for StructuredChatOutputParser.parse across action, finish, and error inputs."""

    def test_parse_action_without_action_input(self) -> None:
        """An action block lacking "action_input" defaults the tool input to an empty dict."""
        output = StructuredChatOutputParser().parse('Action:\n```json\n{"action":"some_action"}\n```')

        assert isinstance(output, ReactAction)
        assert output.tool == "some_action"
        assert output.tool_input == {}

    def test_parse_json_without_action_key(self) -> None:
        """A JSON block missing the "action" key is rejected as unparseable."""
        malformed = 'Action:\n```json\n{"not_action":"search"}\n```'

        with pytest.raises(ValueError, match="Could not parse LLM output"):
            StructuredChatOutputParser().parse(malformed)

    def test_parse_returns_action_for_tool_call(self) -> None:
        """A well-formed tool call yields a ReactAction carrying tool, input, and the raw log."""
        llm_output = (
            'Thought: call tool\nAction:\n```json\n{"action":"search_dataset","action_input":{"query":"python"}}\n```'
        )

        output = StructuredChatOutputParser().parse(llm_output)

        assert isinstance(output, ReactAction)
        assert output.tool == "search_dataset"
        assert output.tool_input == {"query": "python"}
        assert output.log == llm_output

    def test_parse_returns_finish_for_final_answer(self) -> None:
        """The special "Final Answer" action terminates the loop with a ReactFinish."""
        llm_output = 'Thought: done\nAction:\n```json\n{"action":"Final Answer","action_input":"final text"}\n```'

        output = StructuredChatOutputParser().parse(llm_output)

        assert isinstance(output, ReactFinish)
        assert output.return_values == {"output": "final text"}
        assert output.log == llm_output

    def test_parse_returns_finish_for_json_array_payload(self) -> None:
        """A JSON array is not a dict-shaped action, so the whole text becomes the finish output."""
        llm_output = 'Action:\n```json\n[{"action":"search","action_input":"hello"}]\n```'

        output = StructuredChatOutputParser().parse(llm_output)

        assert isinstance(output, ReactFinish)
        assert output.return_values == {"output": llm_output}
        assert output.log == llm_output

    def test_parse_returns_finish_for_plain_text(self) -> None:
        """Free text without any action block is wrapped verbatim in a ReactFinish."""
        llm_output = "No structured action block"

        output = StructuredChatOutputParser().parse(llm_output)

        assert isinstance(output, ReactFinish)
        assert output.return_values == {"output": llm_output}

    def test_parse_raises_value_error_for_invalid_json(self) -> None:
        """Syntactically invalid JSON inside the action block raises ValueError."""
        broken = 'Action:\n```json\n{"action":"search","action_input": }\n```'

        with pytest.raises(ValueError, match="Could not parse LLM output"):
            StructuredChatOutputParser().parse(broken)
|
||||
@ -125,7 +125,11 @@ Run with coverage:
|
||||
- Tests are organized by functionality in classes for better organization
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import string
|
||||
import sys
|
||||
import types
|
||||
from inspect import currentframe
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -604,6 +608,51 @@ class TestRecursiveCharacterTextSplitter:
|
||||
assert "def hello_world" in combined or "hello_world" in combined
|
||||
|
||||
|
||||
class TestTextSplitterBasePaths:
    """Exercise otherwise-uncovered code paths on the base TextSplitter."""

    def test_from_huggingface_tokenizer_success_path(self):
        """from_huggingface_tokenizer builds a working splitter when transformers is importable."""

        class _StubTokenizerBase:
            pass

        class _StubTokenizer(_StubTokenizerBase):
            # One token per character keeps the length function trivial.
            def encode(self, text: str):
                return [ord(char) for char in text]

        stub_module = types.SimpleNamespace(PreTrainedTokenizerBase=_StubTokenizerBase)
        with patch.dict(sys.modules, {"transformers": stub_module}):
            splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer=_StubTokenizer(),
                chunk_size=5,
                chunk_overlap=1,
            )

        assert splitter.split_text("abcdef")

    def test_from_huggingface_tokenizer_import_error(self):
        """A missing transformers package surfaces as a ValueError."""
        with patch.dict(sys.modules, {"transformers": None}):
            with pytest.raises(ValueError, match="Could not import transformers"):
                RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer=object(), chunk_size=5)

    def test_atransform_documents_raises_not_implemented(self):
        """The async transform entry point is intentionally unimplemented."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
        docs = [Document(page_content="x", metadata={})]

        with pytest.raises(NotImplementedError):
            asyncio.run(splitter.atransform_documents(docs))

    def test_merge_splits_logs_warning_for_oversized_total(self):
        """_merge_splits emits a warning — but still merges — when a split exceeds chunk_size."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1)

        with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning:
            merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1])

        assert merged
        mock_warning.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test TokenTextSplitter
|
||||
# ============================================================================
|
||||
@ -662,6 +711,44 @@ class TestTokenTextSplitter:
|
||||
except ImportError:
|
||||
pytest.skip("tiktoken not installed")
|
||||
|
||||
def test_initialization_and_split_with_mocked_tiktoken_encoding(self):
|
||||
"""Cover TokenTextSplitter __init__ else-path and split_text logic."""
|
||||
|
||||
class _FakeEncoding:
|
||||
def encode(self, text: str, allowed_special=None, disallowed_special=None):
|
||||
return [ord(c) for c in text]
|
||||
|
||||
def decode(self, token_ids: list[int]) -> str:
|
||||
return "".join(chr(i) for i in token_ids)
|
||||
|
||||
fake_tiktoken = types.SimpleNamespace(get_encoding=lambda name: _FakeEncoding())
|
||||
with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}):
|
||||
splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=4, chunk_overlap=1)
|
||||
result = splitter.split_text("abcdefgh")
|
||||
|
||||
assert result
|
||||
assert all(isinstance(chunk, str) for chunk in result)
|
||||
|
||||
def test_initialization_with_model_name_uses_encoding_for_model(self):
|
||||
"""Cover TokenTextSplitter model_name init branch."""
|
||||
|
||||
class _FakeEncoding:
|
||||
def encode(self, text: str, allowed_special=None, disallowed_special=None):
|
||||
return [ord(c) for c in text]
|
||||
|
||||
def decode(self, token_ids: list[int]) -> str:
|
||||
return "".join(chr(i) for i in token_ids)
|
||||
|
||||
fake_encoding = _FakeEncoding()
|
||||
fake_tiktoken = types.SimpleNamespace(
|
||||
encoding_for_model=lambda model_name: fake_encoding,
|
||||
get_encoding=lambda name: _FakeEncoding(),
|
||||
)
|
||||
with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}):
|
||||
splitter = TokenTextSplitter(model_name="gpt-4", chunk_size=5, chunk_overlap=1)
|
||||
|
||||
assert splitter._tokenizer is fake_encoding
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test EnhanceRecursiveCharacterTextSplitter
|
||||
@ -731,6 +818,50 @@ class TestEnhanceRecursiveCharacterTextSplitter:
|
||||
assert len(result) > 0
|
||||
assert all(isinstance(chunk, str) for chunk in result)
|
||||
|
||||
    def test_from_encoder_internal_token_encoder_paths(self):
        """
        Test internal _token_encoder branches by capturing local closure from frame.

        This validates:
        - empty texts path
        - embedding model path
        - GPT2Tokenizer fallback path
        - _character_encoder empty-path branch
        """

        class _SpySplitter(EnhanceRecursiveCharacterTextSplitter):
            # Class attributes so the captured closures remain reachable after
            # from_encoder returns.
            captured_token_encoder = None
            captured_character_encoder = None

            def __init__(self, **kwargs):
                # from_encoder presumably defines _token_encoder/_character_encoder as
                # locals and then instantiates the class; the caller's frame therefore
                # exposes those closures at construction time — TODO confirm against
                # core.rag.splitter.fixed_text_splitter.
                frame = currentframe()
                if frame and frame.f_back:
                    _SpySplitter.captured_token_encoder = frame.f_back.f_locals.get("_token_encoder")
                    _SpySplitter.captured_character_encoder = frame.f_back.f_locals.get("_character_encoder")
                super().__init__(**kwargs)

        mock_model = Mock()
        mock_model.get_text_embedding_num_tokens.return_value = [3, 5]

        _SpySplitter.from_encoder(embedding_model_instance=mock_model, chunk_size=10, chunk_overlap=1)
        token_encoder = _SpySplitter.captured_token_encoder
        character_encoder = _SpySplitter.captured_character_encoder

        assert token_encoder is not None
        assert character_encoder is not None
        # Empty-input short-circuit: both encoders return [] without touching the model.
        assert token_encoder([]) == []
        # Embedding-model path: counts come from get_text_embedding_num_tokens.
        assert token_encoder(["abc", "defgh"]) == [3, 5]
        assert character_encoder([]) == []

        with patch(
            "core.rag.splitter.fixed_text_splitter.GPT2Tokenizer.get_num_tokens",
            side_effect=lambda text: len(text) + 1,
        ):
            # Without an embedding model, the token encoder falls back to GPT2Tokenizer;
            # the patched side_effect (len + 1) makes the fallback observable.
            _SpySplitter.from_encoder(embedding_model_instance=None, chunk_size=10, chunk_overlap=1)
            token_encoder_without_model = _SpySplitter.captured_token_encoder
            assert token_encoder_without_model is not None
            assert token_encoder_without_model(["ab", "cdef"]) == [3, 5]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test FixedRecursiveCharacterTextSplitter
|
||||
@ -908,6 +1039,56 @@ class TestFixedRecursiveCharacterTextSplitter:
|
||||
chunks = splitter.split_text(data)
|
||||
assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
|
||||
|
||||
def test_recursive_split_keep_separator_and_recursive_fallback(self):
|
||||
"""Cover keep-separator split branch and recursive _split_text fallback."""
|
||||
text = "short." + ("x" * 60)
|
||||
splitter = FixedRecursiveCharacterTextSplitter(
|
||||
fixed_separator="",
|
||||
separators=[".", " ", ""],
|
||||
chunk_size=10,
|
||||
chunk_overlap=2,
|
||||
keep_separator=True,
|
||||
)
|
||||
|
||||
chunks = splitter.recursive_split_text(text)
|
||||
|
||||
assert chunks
|
||||
assert any("short." in chunk for chunk in chunks)
|
||||
assert any(len(chunk) <= 12 for chunk in chunks)
|
||||
|
||||
def test_recursive_split_newline_separator_filtering(self):
|
||||
"""Cover newline-specific empty filtering branch."""
|
||||
text = "line1\n\nline2\n\nline3"
|
||||
splitter = FixedRecursiveCharacterTextSplitter(
|
||||
fixed_separator="",
|
||||
separators=["\n", ""],
|
||||
chunk_size=50,
|
||||
chunk_overlap=5,
|
||||
)
|
||||
|
||||
chunks = splitter.recursive_split_text(text)
|
||||
|
||||
assert chunks
|
||||
assert all(chunk != "" for chunk in chunks)
|
||||
assert "line1" in "".join(chunks)
|
||||
assert "line2" in "".join(chunks)
|
||||
assert "line3" in "".join(chunks)
|
||||
|
||||
def test_recursive_split_without_new_separator_appends_long_chunk(self):
|
||||
"""Cover branch where no further separators exist and long split is appended directly."""
|
||||
text = "aa\n" + ("b" * 40)
|
||||
splitter = FixedRecursiveCharacterTextSplitter(
|
||||
fixed_separator="",
|
||||
separators=["\n"],
|
||||
chunk_size=10,
|
||||
chunk_overlap=2,
|
||||
)
|
||||
|
||||
chunks = splitter.recursive_split_text(text)
|
||||
|
||||
assert "aa" in chunks
|
||||
assert any(len(chunk) >= 40 for chunk in chunks)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Metadata Preservation
|
||||
|
||||
Loading…
Reference in New Issue
Block a user