test: unit test cases for rag.cleaner, rag.data_post_processor and rag.datasource (#32521)

This commit is contained in:
Rajat Agarwal 2026-03-24 23:49:15 +05:30 committed by GitHub
parent 36cc1bf025
commit 6f137fdb00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 13766 additions and 47 deletions

View File

@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector):
) )
) )
score_threshold = float(kwargs.get("score_threshold") or 0.0)
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score >= score_threshold: if score >= score_threshold:
if doc.metadata is not None: if doc.metadata is not None:
doc.metadata["score"] = score doc.metadata["score"] = score
docs.append(doc) docs.append(doc)
return docs return docs

View File

@ -211,3 +211,16 @@ class TestCleanProcessor:
text = "[Text with (parens) and symbols](https://example.com)" text = "[Text with (parens) and symbols](https://example.com)"
expected = "[Text with (parens) and symbols](https://example.com)" expected = "[Text with (parens) and symbols](https://example.com)"
assert CleanProcessor.clean(text, process_rule) == expected assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_remove_urls_emails_preserves_markdown_image_links(self):
    """Remove plain URLs and emails while preserving markdown image links."""
    process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
    text = "Email test@example.com and remove https://remove.com but keep ![diagram](https://example.com/image.png)"
    result = CleanProcessor.clean(text, process_rule)
    # The bare e-mail and URL are stripped; the markdown image link must survive intact.
    assert result == "Email and remove but keep ![diagram](https://example.com/image.png)"
def test_filter_string_returns_input_text(self):
    """Test filter_string passthrough behavior: input text is returned unchanged."""
    processor = CleanProcessor()
    assert processor.filter_string("raw text") == "raw text"

View File

@ -0,0 +1,249 @@
from unittest.mock import MagicMock, patch
import pytest
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
def _doc(content: str) -> Document:
    """Shorthand for building a Document with only ``page_content`` set."""
    return Document(page_content=content)
class TestDataPostProcessor:
    """Unit tests for DataPostProcessor: construction, invoke pipeline, and runner selection."""

    def test_init_sets_rerank_and_reorder_runners(self):
        # __init__ should delegate to the two private factories and store their results.
        rerank_runner = object()
        reorder_runner = object()
        with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock:
            with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock:
                processor = DataPostProcessor(
                    tenant_id="tenant-1",
                    reranking_mode=RerankMode.WEIGHTED_SCORE,
                    reranking_model={"config": "value"},
                    weights={"weight": "value"},
                    reorder_enabled=True,
                )
                assert processor.rerank_runner is rerank_runner
                assert processor.reorder_runner is reorder_runner
                rerank_mock.assert_called_once_with(
                    RerankMode.WEIGHTED_SCORE,
                    "tenant-1",
                    {"config": "value"},
                    {"weight": "value"},
                )
                reorder_mock.assert_called_once_with(True)

    def test_invoke_applies_rerank_then_reorder(self):
        # invoke() runs rerank first, then feeds the reranked docs to the reorder runner.
        original_documents = [_doc("doc-a")]
        reranked_documents = [_doc("doc-b")]
        reordered_documents = [_doc("doc-c")]
        processor = DataPostProcessor.__new__(DataPostProcessor)  # bypass __init__
        processor.rerank_runner = MagicMock()
        processor.rerank_runner.run.return_value = reranked_documents
        processor.reorder_runner = MagicMock()
        processor.reorder_runner.run.return_value = reordered_documents
        result = processor.invoke(
            query="how to test",
            documents=original_documents,
            score_threshold=0.3,
            top_n=2,
            user="user-1",
            query_type=QueryType.IMAGE_QUERY,
        )
        processor.rerank_runner.run.assert_called_once_with(
            "how to test",
            original_documents,
            0.3,
            2,
            "user-1",
            QueryType.IMAGE_QUERY,
        )
        processor.reorder_runner.run.assert_called_once_with(reranked_documents)
        assert result == reordered_documents

    def test_invoke_returns_original_documents_when_no_runner_is_configured(self):
        # With neither runner configured, invoke() is a no-op passthrough.
        documents = [_doc("doc-a"), _doc("doc-b")]
        processor = DataPostProcessor.__new__(DataPostProcessor)
        processor.rerank_runner = None
        processor.reorder_runner = None
        assert processor.invoke(query="query", documents=documents) == documents

    def test_get_rerank_runner_for_weighted_score(self):
        # Weighted-score mode converts the raw weights dict into typed settings
        # before handing them to the factory.
        weights_config = {
            "vector_setting": {
                "vector_weight": 0.7,
                "embedding_provider_name": "provider-x",
                "embedding_model_name": "embedding-y",
            },
            "keyword_setting": {"keyword_weight": 0.3},
        }
        expected_runner = object()
        processor = DataPostProcessor.__new__(DataPostProcessor)
        with patch(
            "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner",
            return_value=expected_runner,
        ) as factory_mock:
            result = processor._get_rerank_runner(
                reranking_mode=RerankMode.WEIGHTED_SCORE,
                tenant_id="tenant-1",
                reranking_model=None,
                weights=weights_config,
            )
            assert result is expected_runner
            kwargs = factory_mock.call_args.kwargs
            assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE
            assert kwargs["tenant_id"] == "tenant-1"
            assert kwargs["weights"].vector_setting.vector_weight == 0.7
            assert kwargs["weights"].vector_setting.embedding_provider_name == "provider-x"
            assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y"
            assert kwargs["weights"].keyword_setting.keyword_weight == 0.3

    def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self):
        # If no model instance can be resolved, the factory must not be called at all.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        reranking_model = {
            "reranking_provider_name": "provider-x",
            "reranking_model_name": "model-y",
        }
        with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock:
            with patch(
                "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner"
            ) as factory_mock:
                result = processor._get_rerank_runner(
                    reranking_mode=RerankMode.RERANKING_MODEL,
                    tenant_id="tenant-1",
                    reranking_model=reranking_model,
                    weights=None,
                )
                assert result is None
                model_mock.assert_called_once_with("tenant-1", reranking_model)
                factory_mock.assert_not_called()

    def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self):
        # A resolved model instance is forwarded to the factory as rerank_model_instance.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        model_instance = object()
        expected_runner = object()
        with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance):
            with patch(
                "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner",
                return_value=expected_runner,
            ) as factory_mock:
                result = processor._get_rerank_runner(
                    reranking_mode=RerankMode.RERANKING_MODEL,
                    tenant_id="tenant-1",
                    reranking_model={
                        "reranking_provider_name": "provider-x",
                        "reranking_model_name": "model-y",
                    },
                    weights=None,
                )
                assert result is expected_runner
                factory_mock.assert_called_once_with(
                    runner_type=RerankMode.RERANKING_MODEL,
                    rerank_model_instance=model_instance,
                )

    def test_get_rerank_runner_returns_none_for_unsupported_mode(self):
        # Unknown modes — or weighted-score mode without weights — yield no runner.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None
        assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None

    def test_get_reorder_runner_by_flag(self):
        # reorder_enabled toggles between a ReorderRunner and None.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        assert isinstance(processor._get_reorder_runner(True), ReorderRunner)
        assert processor._get_reorder_runner(False) is None

    def test_get_rerank_model_instance_returns_none_when_config_is_missing(self):
        processor = DataPostProcessor.__new__(DataPostProcessor)
        assert processor._get_rerank_model_instance("tenant-1", None) is None

    def test_get_rerank_model_instance_raises_key_error_for_incomplete_config(self):
        # A config missing reranking_model_name raises before ModelManager is used.
        processor = DataPostProcessor.__new__(DataPostProcessor)
        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            with pytest.raises(KeyError, match="reranking_model_name"):
                processor._get_rerank_model_instance(
                    tenant_id="tenant-1",
                    reranking_model={"reranking_provider_name": "provider-x"},
                )
            manager_instance.get_model_instance.assert_not_called()

    def test_get_rerank_model_instance_success(self):
        processor = DataPostProcessor.__new__(DataPostProcessor)
        model_instance = object()
        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            manager_instance.get_model_instance.return_value = model_instance
            result = processor._get_rerank_model_instance(
                tenant_id="tenant-1",
                reranking_model={
                    "reranking_provider_name": "provider-x",
                    "reranking_model_name": "reranker-1",
                },
            )
            assert result is model_instance
            manager_instance.get_model_instance.assert_called_once_with(
                tenant_id="tenant-1",
                provider="provider-x",
                model_type=ModelType.RERANK,
                model="reranker-1",
            )

    def test_get_rerank_model_instance_handles_authorization_error(self):
        # Authorization failures are swallowed and reported as "no model available".
        processor = DataPostProcessor.__new__(DataPostProcessor)
        with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls:
            manager_instance = manager_cls.return_value
            manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized")
            result = processor._get_rerank_model_instance(
                tenant_id="tenant-1",
                reranking_model={
                    "reranking_provider_name": "provider-x",
                    "reranking_model_name": "reranker-1",
                },
            )
            assert result is None
class TestReorderRunner:
    """Tests for ReorderRunner's document reordering (even positions first, odd reversed)."""

    def test_run_reorders_even_sized_document_list(self):
        documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4"), _doc("5")]
        reordered = ReorderRunner().run(documents)
        assert [document.page_content for document in reordered] == ["0", "2", "4", "5", "3", "1"]

    def test_run_handles_odd_sized_and_empty_document_lists(self):
        # Odd-length input and an empty list are edge cases of the interleave.
        odd_documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4")]
        runner = ReorderRunner()
        odd_reordered = runner.run(odd_documents)
        assert [document.page_content for document in odd_reordered] == ["0", "2", "4", "3", "1"]
        assert runner.run([]) == []

View File

@ -0,0 +1,414 @@
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.rag.datasource.keyword.jieba.jieba as jieba_module
from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default
from core.rag.models.document import Document
class _DummyLock:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class _Field:
def __init__(self, name: str):
self._name = name
def __eq__(self, other):
return ("eq", self._name, other)
def in_(self, values):
return ("in", self._name, tuple(values))
class _FakeQuery:
def __init__(self):
self.where_calls: list[tuple] = []
def where(self, *conditions):
self.where_calls.append(conditions)
return self
class _FakeExecuteResult:
def __init__(self, segments: list[SimpleNamespace]):
self._segments = segments
def scalars(self):
return self
def all(self):
return self._segments
class _FakeSelect:
def __init__(self):
self.where_conditions: tuple | None = None
def where(self, *conditions):
self.where_conditions = conditions
return self
def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None):
return SimpleNamespace(
data_source_type=data_source_type,
keyword_table_dict=keyword_table_dict,
keyword_table="",
)
def _dataset(dataset_keyword_table=None, keyword_number=None):
return SimpleNamespace(
id="dataset-1",
tenant_id="tenant-1",
keyword_number=keyword_number,
dataset_keyword_table=dataset_keyword_table,
)
@pytest.fixture
def patched_runtime(monkeypatch):
    """Swap the jieba module's db/storage/redis globals for test doubles.

    Returns a namespace exposing the mocks so tests can assert on them.
    """
    session = MagicMock()
    db = SimpleNamespace(session=session)
    storage = MagicMock()
    lock = MagicMock(return_value=_DummyLock())
    redis_client = SimpleNamespace(lock=lock)
    monkeypatch.setattr(jieba_module, "db", db)
    monkeypatch.setattr(jieba_module, "storage", storage)
    monkeypatch.setattr(jieba_module, "redis_client", redis_client)
    return SimpleNamespace(session=session, storage=storage, lock=lock)
def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime):
    """create() extracts keywords per document, updates segments, and is chainable."""
    dataset = _dataset(_dataset_keyword_table(), keyword_number=2)
    keyword = Jieba(dataset)
    handler = MagicMock()
    handler.extract_keywords.return_value = {"kw1", "kw2"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    result = keyword.create(
        [
            Document(page_content="alpha", metadata={"doc_id": "node-1"}),
            SimpleNamespace(page_content="ignored", metadata=None),  # docs without metadata are skipped
        ]
    )
    assert result is keyword
    keyword._update_segment_keywords.assert_called_once()
    call_args = keyword._update_segment_keywords.call_args.args
    assert call_args[0] == "dataset-1"
    assert call_args[1] == "node-1"
    assert set(call_args[2]) == {"kw1", "kw2"}
    saved_table = keyword._save_dataset_keyword_table.call_args.args[0]
    assert saved_table["kw1"] == {"node-1"}
    assert saved_table["kw2"] == {"node-1"}
    # Indexing must be serialized under a per-dataset redis lock.
    patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600)


def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime):
    """An empty per-document keyword list falls back to the extractor; a non-empty one is used as-is."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"auto"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    texts = [
        Document(page_content="extract-this", metadata={"doc_id": "node-1"}),
        Document(page_content="use-manual", metadata={"doc_id": "node-2"}),
    ]
    keyword.add_texts(texts, keywords_list=[[], ["manual"]])
    assert keyword._update_segment_keywords.call_count == 2
    first_call = keyword._update_segment_keywords.call_args_list[0].args
    second_call = keyword._update_segment_keywords.call_args_list[1].args
    assert set(first_call[2]) == {"auto"}
    assert second_call[2] == ["manual"]
    keyword._save_dataset_keyword_table.assert_called_once()


def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime):
    """Without keywords_list every document goes through the jieba extractor."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"from-extractor"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})])
    # keyword_number (1) is forwarded as the max-keywords argument.
    handler.extract_keywords.assert_called_once_with("content", 1)
    assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"}


def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch):
    """text_exists() is False without a table, else checks node-id membership."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None))
    assert keyword.text_exists("node-1") is False
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    assert keyword.text_exists("node-2") is True
    assert keyword.text_exists("node-x") is False


def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime):
    """delete_by_ids() removes the ids from the table and persists the result."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}}))
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.delete_by_ids(["node-1"])
    keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"])
    keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}})


def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime):
    """When no keyword table exists, nothing is deleted and None is persisted."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None))
    monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.delete_by_ids(["node-1"])
    keyword._delete_ids_from_keyword_table.assert_not_called()
    keyword._save_dataset_keyword_table.assert_called_once_with(None)
def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime):
    """search() ranks node ids, filters by document ids, and maps segments to Documents."""

    class _FakeDocumentSegment:
        # Column doubles; comparisons produce tuples captured by _FakeQuery.
        dataset_id = _Field("dataset_id")
        index_node_id = _Field("index_node_id")
        document_id = _Field("document_id")

    keyword = Jieba(_dataset(_dataset_keyword_table()))
    query_stmt = _FakeQuery()
    patched_runtime.session.query.return_value = query_stmt
    patched_runtime.session.execute.return_value = _FakeExecuteResult(
        [
            SimpleNamespace(
                index_node_id="node-2",
                content="segment-content",
                index_node_hash="hash-2",
                document_id="doc-2",
                dataset_id="dataset-1",
            )
        ]
    )
    monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
    monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))
    documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])
    # One where() for the base dataset/node filter, one for the document-ids filter.
    assert len(query_stmt.where_calls) == 2
    assert len(documents) == 1
    assert documents[0].page_content == "segment-content"
    assert documents[0].metadata["doc_id"] == "node-2"
    assert documents[0].metadata["doc_hash"] == "hash-2"


def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime):
    """delete() always removes the DB row; object-storage-backed tables also delete the file."""
    db_keyword = _dataset_keyword_table(data_source_type="database")
    file_keyword = _dataset_keyword_table(data_source_type="object_storage")
    keyword_db = Jieba(_dataset(db_keyword))
    keyword_db.delete()
    patched_runtime.storage.delete.assert_not_called()
    keyword_file = Jieba(_dataset(file_keyword))
    keyword_file.delete()
    patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt")
    assert patched_runtime.session.delete.call_count == 2
    assert patched_runtime.session.commit.call_count == 2


def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime):
    """Database-backed tables serialize the keyword map into the row's keyword_table column."""
    dataset_keyword_table = _dataset_keyword_table(data_source_type="database")
    keyword = Jieba(_dataset(dataset_keyword_table))
    keyword._save_dataset_keyword_table({"kw": {"node-1"}})
    assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table
    assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table
    patched_runtime.session.commit.assert_called_once()


def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime):
    """File-backed tables delete any existing blob then save serialized bytes."""
    dataset_keyword_table = _dataset_keyword_table(data_source_type="file")
    keyword = Jieba(_dataset(dataset_keyword_table))
    patched_runtime.storage.exists.return_value = True
    keyword._save_dataset_keyword_table({"kw": {"node-1"}})
    patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt")
    patched_runtime.storage.save.assert_called_once()
    save_args = patched_runtime.storage.save.call_args.args
    assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt"
    assert isinstance(save_args[1], bytes)
def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime):
    """An existing table payload yields its inner table dict; a missing payload yields {}."""
    existing = _dataset_keyword_table(
        keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}}
    )
    keyword = Jieba(_dataset(existing))
    assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]}
    missing_payload = _dataset_keyword_table(keyword_table_dict=None)
    keyword_with_missing_payload = Jieba(_dataset(missing_payload))
    assert keyword_with_missing_payload._get_dataset_keyword_table() == {}


def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime):
    """With no DatasetKeywordTable row, one is created, persisted, and an empty table returned."""
    created_tables: list[SimpleNamespace] = []

    def _fake_dataset_keyword_table(**kwargs):
        # Capture construction kwargs so the test can inspect the new row.
        kwargs.setdefault("keyword_table", "")
        kwargs.setdefault("keyword_table_dict", None)
        table = SimpleNamespace(**kwargs)
        created_tables.append(table)
        return table

    keyword = Jieba(_dataset(dataset_keyword_table=None))
    monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table)
    monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database")
    result = keyword._get_dataset_keyword_table()
    assert result == {}
    assert len(created_tables) == 1
    assert created_tables[0].dataset_id == "dataset-1"
    assert created_tables[0].data_source_type == "database"
    assert '"index_id":"dataset-1"' in created_tables[0].keyword_table
    patched_runtime.session.add.assert_called_once_with(created_tables[0])
    patched_runtime.session.commit.assert_called_once()


def test_add_and_delete_ids_from_keyword_table_helpers():
    """Adding merges node ids into keyword buckets; deleting drops empty buckets."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}}
    updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"])
    assert updated["kw1"] == {"node-1", "node-3"}
    assert updated["kw3"] == {"node-3"}
    deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"])
    assert "kw3" not in deleted
    assert "kw1" not in deleted
    assert deleted["kw2"] == {"node-2"}


def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch):
    """Nodes matching more query keywords rank first; k caps the result length."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    handler = MagicMock()
    handler.extract_keywords.return_value = ["kw-a", "kw-b"]
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    ranked_ids = keyword._retrieve_ids_by_query(
        {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}},
        "query",
        k=1,
    )
    # node-2 matches both query keywords, so it wins with k=1.
    assert ranked_ids == ["node-2"]


def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime):
    """Only an existing segment gets its keywords replaced and committed."""

    class _FakeDocumentSegment:
        dataset_id = _Field("dataset_id")
        index_node_id = _Field("index_node_id")

    monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
    monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    segment = SimpleNamespace(keywords=[])
    patched_runtime.session.scalar.return_value = segment
    keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"])
    assert segment.keywords == ["kw1", "kw2"]
    patched_runtime.session.add.assert_called_once_with(segment)
    patched_runtime.session.commit.assert_called_once()
    # A missing segment must cause no writes at all.
    patched_runtime.session.reset_mock()
    patched_runtime.session.scalar.return_value = None
    keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"])
    patched_runtime.session.add.assert_not_called()
    patched_runtime.session.commit.assert_not_called()


def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch):
    """Both single-segment entry points update the segment and persist the table."""
    keyword = Jieba(_dataset(_dataset_keyword_table()))
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock())
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    keyword.create_segment_keywords("node-1", ["kw"])
    keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"])
    keyword._save_dataset_keyword_table.assert_called_once()
    keyword._save_dataset_keyword_table.reset_mock()
    keyword.update_segment_keywords_index("node-2", ["kw2"])
    keyword._save_dataset_keyword_table.assert_called_once()


def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch):
    """Provided keywords are used verbatim; empty lists fall back to extraction."""
    keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2))
    handler = MagicMock()
    handler.extract_keywords.return_value = {"auto"}
    monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler)
    monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={}))
    monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock())
    first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None)
    second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None)
    keyword.multi_create_segment_keywords(
        [
            {"segment": first_segment, "keywords": ["manual"]},
            {"segment": second_segment, "keywords": []},
        ]
    )
    assert first_segment.keywords == ["manual"]
    assert second_segment.keywords == ["auto"]
    saved_table = keyword._save_dataset_keyword_table.call_args.args[0]
    assert saved_table["manual"] == {"node-1"}
    assert saved_table["auto"] == {"node-2"}


def test_set_orjson_default_and_dumps_with_sets():
    """set_orjson_default serializes sets (only); dumps_with_sets round-trips them via JSON."""
    assert set(set_orjson_default({"a", "b"})) == {"a", "b"}
    with pytest.raises(TypeError, match="is not JSON serializable"):
        set_orjson_default(("not", "a", "set"))
    payload = {"items": {"a", "b"}}
    json_payload = dumps_with_sets(payload)
    decoded = json.loads(json_payload)
    assert set(decoded["items"]) == {"a", "b"}

View File

@ -0,0 +1,142 @@
import sys
import types
from types import SimpleNamespace
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class _DummyTFIDF:
def __init__(self):
self.stop_words = set()
@staticmethod
def extract_tags(sentence: str, top_k: int | None = 20, **kwargs):
return ["alpha_beta", "during", "gamma"]
def _install_fake_jieba_modules(
    monkeypatch,
    analyse_module: types.ModuleType,
    jieba_attrs: dict[str, object] | None = None,
    tfidf_module: types.ModuleType | None = None,
):
    """Install fake ``jieba``/``jieba.analyse`` modules into sys.modules.

    Optionally sets extra attributes on the fake top-level module and installs
    (or removes) a fake ``jieba.analyse.tfidf`` submodule.
    """
    jieba_module = types.ModuleType("jieba")
    jieba_module.__path__ = []  # mark as a package so submodule imports resolve
    if jieba_attrs:
        for key, value in jieba_attrs.items():
            setattr(jieba_module, key, value)
    jieba_module.analyse = analyse_module
    analyse_module.__package__ = "jieba"
    monkeypatch.setitem(sys.modules, "jieba", jieba_module)
    monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module)
    if tfidf_module is not None:
        monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module)
    else:
        # Ensure a stale real submodule from a previous import can't leak in.
        monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False)
def test_init_uses_existing_default_tfidf(monkeypatch):
    """If jieba.analyse already has a default_tfidf, the handler reuses it and sets stopwords."""
    analyse_module = types.ModuleType("jieba.analyse")
    default_tfidf = _DummyTFIDF()
    analyse_module.default_tfidf = default_tfidf
    _install_fake_jieba_modules(monkeypatch, analyse_module)
    handler = JiebaKeywordTableHandler()
    assert handler._tfidf is default_tfidf
    assert handler._tfidf.stop_words == STOPWORDS


def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch):
    """Without a default, a TFIDF instance is built from analyse.TFIDF and cached as default."""
    analyse_module = types.ModuleType("jieba.analyse")
    analyse_module.default_tfidf = None

    class _TFIDFFactory(_DummyTFIDF):
        pass

    analyse_module.TFIDF = _TFIDFFactory
    _install_fake_jieba_modules(monkeypatch, analyse_module)
    handler = JiebaKeywordTableHandler()
    assert isinstance(handler._tfidf, _TFIDFFactory)
    assert analyse_module.default_tfidf is handler._tfidf


def test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch):
    """When analyse lacks TFIDF, it is imported from the jieba.analyse.tfidf submodule."""
    analyse_module = types.ModuleType("jieba.analyse")
    analyse_module.default_tfidf = None
    tfidf_module = types.ModuleType("jieba.analyse.tfidf")

    class _ImportedTFIDF(_DummyTFIDF):
        pass

    tfidf_module.TFIDF = _ImportedTFIDF
    _install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module)
    handler = JiebaKeywordTableHandler()
    assert isinstance(handler._tfidf, _ImportedTFIDF)
    assert analyse_module.default_tfidf is handler._tfidf


def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch):
    """With no TFIDF anywhere, the handler builds its frequency-based fallback extractor."""
    analyse_module = types.ModuleType("jieba.analyse")
    analyse_module.default_tfidf = None
    _install_fake_jieba_modules(monkeypatch, analyse_module)
    handler = JiebaKeywordTableHandler()
    # "two" is the most frequent non-stopword token in the sample text.
    fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1)
    assert fallback_keywords == ["two"]


def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch):
    """The fallback tokenizes with jieba.lcut when present."""
    analyse_module = types.ModuleType("jieba.analyse")
    _install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]})
    tfidf = JiebaKeywordTableHandler._build_fallback_tfidf()
    assert tfidf.extract_tags("ignored", topK=1) == ["x"]


def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch):
    """The fallback degrades to jieba.cut (an iterator) when lcut is absent."""
    analyse_module = types.ModuleType("jieba.analyse")
    _install_fake_jieba_modules(
        monkeypatch,
        analyse_module,
        jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])},
    )
    tfidf = JiebaKeywordTableHandler._build_fallback_tfidf()
    assert tfidf.extract_tags("ignored", topK=1) == ["foo"]


def test_extract_keywords_expands_subtokens():
    """extract_keywords keeps the raw tags and adds their non-stopword subtokens."""
    handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler)
    handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"])
    keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3)
    assert "alpha-beta" in keywords
    assert "alpha" in keywords
    assert "beta" in keywords
    assert "during" in keywords
    assert "gamma" in keywords


def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens():
    """Subtoken expansion drops stopword subtokens while keeping the original token."""
    handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler)
    expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"})
    assert "alpha-during-beta" in expanded
    assert "alpha" in expanded
    assert "beta" in expanded
    assert "during" not in expanded  # stopword subtoken is filtered out

View File

@ -0,0 +1,6 @@
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
def test_stopwords_loaded():
    """Sanity-check that the bundled STOPWORDS set contains common English stopwords."""
    assert "during" in STOPWORDS
    assert "the" in STOPWORDS

View File

@ -0,0 +1,97 @@
from types import SimpleNamespace
import pytest
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
class _KeywordThatRaises(BaseKeyword):
    """Subclass that delegates every abstract method to the base implementation,
    so tests can verify the base class raises NotImplementedError."""

    def create(self, texts: list[Document], **kwargs):
        return super().create(texts, **kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        return super().add_texts(texts, **kwargs)

    def text_exists(self, id: str) -> bool:
        return super().text_exists(id)

    def delete_by_ids(self, ids: list[str]):
        return super().delete_by_ids(ids)

    def delete(self):
        return super().delete()

    def search(self, query: str, **kwargs):
        return super().search(query, **kwargs)
class _KeywordForHelpers(BaseKeyword):
    """Concrete no-op subclass used to exercise BaseKeyword's helper methods.

    ``text_exists`` answers from a preset id set so deduplication is testable.
    """

    def __init__(self, dataset, existing_ids: set[str] | None = None):
        super().__init__(dataset)
        self._existing_ids = existing_ids or set()

    def create(self, texts: list[Document], **kwargs):
        return self

    def add_texts(self, texts: list[Document], **kwargs):
        return None

    def text_exists(self, id: str) -> bool:
        return id in self._existing_ids

    def delete_by_ids(self, ids: list[str]):
        return None

    def delete(self):
        return None

    def search(self, query: str, **kwargs):
        return []
def test_abstract_methods_raise_not_implemented():
keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1"))
with pytest.raises(NotImplementedError):
keyword.create([])
with pytest.raises(NotImplementedError):
keyword.add_texts([])
with pytest.raises(NotImplementedError):
keyword.text_exists("doc-1")
with pytest.raises(NotImplementedError):
keyword.delete_by_ids(["doc-1"])
with pytest.raises(NotImplementedError):
keyword.delete()
with pytest.raises(NotImplementedError):
keyword.search("query")
def test_filter_duplicate_texts_removes_existing_doc_ids():
keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"})
texts = [
Document(page_content="keep", metadata={"doc_id": "keep"}),
Document(page_content="duplicate", metadata={"doc_id": "duplicate"}),
SimpleNamespace(page_content="without-metadata", metadata=None),
]
filtered = keyword._filter_duplicate_texts(texts)
assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"]
assert any(text.metadata is None for text in filtered)
def test_get_uuids_returns_only_docs_with_metadata():
    """_get_uuids collects doc_id values, skipping documents that have no metadata."""
    keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"))
    texts = [
        Document(page_content="doc-1", metadata={"doc_id": "doc-1"}),
        Document(page_content="doc-2", metadata={"doc_id": "doc-2"}),
        SimpleNamespace(page_content="doc-3", metadata=None),
    ]
    assert keyword._get_uuids(texts) == ["doc-1", "doc-2"]

View File

@ -0,0 +1,84 @@
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.keyword.keyword_type import KeyWordType
from core.rag.models.document import Document
def test_get_keyword_factory_returns_jieba_factory(monkeypatch):
    """KeyWordType.JIEBA resolves to the Jieba class from its lazily imported module."""
    fake_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba")

    class FakeJieba:
        pass

    fake_module.Jieba = FakeJieba
    # Pre-install the stub so the factory's deferred import picks it up instead of real jieba.
    monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", fake_module)
    assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba
def test_get_keyword_factory_raises_for_unsupported_type():
    """Unknown keyword store types are rejected with a descriptive ValueError."""
    with pytest.raises(ValueError, match="Keyword store unsupported is not supported"):
        Keyword.get_keyword_factory("unsupported")
def test_keyword_initialization_uses_configured_factory(monkeypatch):
    """Keyword.__init__ builds its processor through the factory selected by dify_config.KEYWORD_STORE."""
    dataset = SimpleNamespace(id="dataset-1")
    fake_processor = MagicMock()
    monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA)
    # The patched factory returns a constructor that ignores the dataset and yields our mock.
    monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor))
    keyword = Keyword(dataset)
    assert keyword._keyword_processor is fake_processor
def test_keyword_methods_forward_to_processor():
    """Every public Keyword method delegates, with its arguments, to the wrapped processor."""
    processor = MagicMock()
    processor.text_exists.return_value = True
    processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})]
    # __new__ bypasses __init__ so the mock processor can be injected directly.
    keyword = Keyword.__new__(Keyword)
    keyword._keyword_processor = processor
    docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})]
    keyword.create(docs, foo="bar")
    keyword.add_texts(docs, batch=True)
    assert keyword.text_exists("doc-1") is True
    keyword.delete_by_ids(["doc-1"])
    keyword.delete()
    assert keyword.search("query", top_k=1) == processor.search.return_value
    processor.create.assert_called_once_with(docs, foo="bar")
    processor.add_texts.assert_called_once_with(docs, batch=True)
    processor.text_exists.assert_called_once_with("doc-1")
    processor.delete_by_ids.assert_called_once_with(["doc-1"])
    processor.delete.assert_called_once()
    processor.search.assert_called_once_with("query", top_k=1)
def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes():
    """__getattr__ proxies only callables; non-callables and a missing processor raise AttributeError."""

    class Processor:
        value = 1

        @staticmethod
        def custom():
            return "ok"

    keyword = Keyword.__new__(Keyword)
    keyword._keyword_processor = Processor()
    # Callable attributes are forwarded ...
    assert keyword.custom() == "ok"
    # ... but plain data attributes are not exposed through the proxy.
    with pytest.raises(AttributeError):
        _ = keyword.value
    keyword._keyword_processor = None
    with pytest.raises(AttributeError):
        _ = keyword.missing_method

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,74 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
def test_validate_distance_function_accepts_supported_values():
    """Supported distance functions pass through _validate_distance_function unchanged."""
    factory = AlibabaCloudMySQLVectorFactory()
    assert factory._validate_distance_function("cosine") == "cosine"
    assert factory._validate_distance_function("euclidean") == "euclidean"
def test_validate_distance_function_rejects_unsupported_values():
    """Unsupported distance functions raise a ValueError."""
    factory = AlibabaCloudMySQLVectorFactory()
    with pytest.raises(ValueError, match="Invalid distance function"):
        factory._validate_distance_function("dot_product")
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch):
    """When the dataset already has an index struct, its class_prefix is reused as the collection name."""
    factory = AlibabaCloudMySQLVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}},
        index_struct=None,
    )
    # Provide the full MySQL connection config the factory reads from dify_config.
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306)
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5)
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "cosine")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 6)
    with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
        assert result == "vector"
        assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch):
    """Without an index struct the factory derives the collection name and persists a new index struct."""
    factory = AlibabaCloudMySQLVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-2",
        index_struct_dict=None,
        index_struct=None,
    )
    # Stub the name generator so the derived collection name is deterministic.
    monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306)
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5)
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "euclidean")
    monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 12)
    with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
        assert result == "vector"
        vector_cls.assert_called_once()
        assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2"
        # init_vector is expected to write the freshly generated index struct back to the dataset.
        assert dataset.index_struct is not None

View File

@ -0,0 +1,133 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from core.rag.models.document import Document
def test_init_prefers_openapi_when_api_config_is_provided():
    """AnalyticdbVector routes to the OpenAPI implementation when api_config is supplied."""
    api_config = AnalyticdbVectorOpenAPIConfig(
        access_key_id="ak",
        access_key_secret="sk",
        region_id="cn-hangzhou",
        instance_id="instance-1",
        account="account",
        account_password="password",
        namespace="dify",
        namespace_password="ns-password",
    )
    with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls:
        vector = AnalyticdbVector("COLLECTION", api_config=api_config, sql_config=None)
        assert vector.analyticdb_vector == "openapi_runner"
        openapi_cls.assert_called_once_with("COLLECTION", api_config)
def test_init_uses_sql_implementation_when_api_config_is_missing():
    """AnalyticdbVector falls back to the SQL implementation when only sql_config is supplied."""
    sql_config = AnalyticdbVectorBySqlConfig(
        host="localhost",
        port=5432,
        account="account",
        account_password="password",
        min_connection=1,
        max_connection=2,
        namespace="dify",
    )
    with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls:
        vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config)
        assert vector.analyticdb_vector == "sql_runner"
        sql_cls.assert_called_once_with("COLLECTION", sql_config)
def test_init_raises_when_both_configs_are_missing():
    """Constructing AnalyticdbVector without any backend config is a hard error."""
    with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"):
        AnalyticdbVector("COLLECTION", api_config=None, sql_config=None)
def test_vector_methods_delegate_to_underlying_implementation():
    """AnalyticdbVector forwards every vector operation to the wrapped runner object."""
    runner = MagicMock()
    runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})]
    runner.search_by_full_text.return_value = [Document(page_content="t", metadata={"doc_id": "2"})]
    runner.text_exists.return_value = True
    # __new__ bypasses __init__ so the mock runner can be injected directly.
    vector = AnalyticdbVector.__new__(AnalyticdbVector)
    vector.analyticdb_vector = runner
    texts = [Document(page_content="hello", metadata={"doc_id": "d1"})]
    vector.create(texts=texts, embeddings=[[0.1, 0.2]])
    vector.add_texts(documents=texts, embeddings=[[0.1, 0.2]])
    assert vector.text_exists("d1") is True
    vector.delete_by_ids(["d1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value
    assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value
    vector.delete()
    # create() sizes the collection from the embedding length (2 floats above).
    runner._create_collection_if_not_exists.assert_called_once_with(2)
    runner.add_texts.assert_any_call(texts, [[0.1, 0.2]])
    runner.delete_by_ids.assert_called_once_with(["d1"])
    runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1")
    runner.delete.assert_called_once()
def test_get_type_is_analyticdb():
    """get_type reports the analyticdb vector-store identifier without needing __init__."""
    vector = AnalyticdbVector.__new__(AnalyticdbVector)
    assert vector.get_type() == "analyticdb"
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
    """With no ANALYTICDB_HOST the factory builds an OpenAPI config and lowercases the collection name."""
    factory = AnalyticdbVectorFactory()
    dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # A missing host selects the OpenAPI code path.
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", None)
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_ID", "ak")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_SECRET", "sk")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_REGION_ID", "cn-hz")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_INSTANCE_ID", "instance")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE_PASSWORD", "ns-password")
    with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
        assert result == "vector"
        args = vector_cls.call_args.args
        assert args[0] == "auto_collection"
        assert isinstance(args[1], AnalyticdbVectorOpenAPIConfig)
        assert args[2] is None
        assert dataset.index_struct is not None
def test_factory_builds_sql_config_when_host_is_present(monkeypatch):
    """With ANALYTICDB_HOST set the factory builds a SQL config and reuses the lowercased class_prefix."""
    factory = AnalyticdbVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
    )
    # A configured host selects the direct-SQL code path.
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", "127.0.0.1")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PORT", 5432)
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password")
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MIN_CONNECTION", 1)
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MAX_CONNECTION", 3)
    monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify")
    with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
        assert result == "vector"
        args = vector_cls.call_args.args
        assert args[0] == "existing"
        assert args[1] is None
        assert isinstance(args[2], AnalyticdbVectorBySqlConfig)

View File

@ -0,0 +1,384 @@
import json
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.models.document import Document
def _request_class(name: str):
class _Request:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
_Request.__name__ = name
return _Request
def _install_openapi_stubs(monkeypatch):
    """Install fake Alibaba Cloud SDK modules into sys.modules.

    The production module imports ``alibabacloud_gpdb20160503``, ``alibabacloud_tea_openapi``
    and ``Tea`` lazily, so these stubs let the tests run without the real SDKs installed.
    Returns a SimpleNamespace exposing the stubbed request models, TeaException and Config.
    """
    gpdb_package = types.ModuleType("alibabacloud_gpdb20160503")
    gpdb_package.__path__ = []
    gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models")
    # One kwargs-absorbing stub per request class the production code instantiates.
    for class_name in [
        "InitVectorDatabaseRequest",
        "DescribeNamespaceRequest",
        "CreateNamespaceRequest",
        "DescribeCollectionRequest",
        "CreateCollectionRequest",
        "UpsertCollectionDataRequestRows",
        "UpsertCollectionDataRequest",
        "QueryCollectionDataRequest",
        "DeleteCollectionDataRequest",
        "DeleteCollectionRequest",
    ]:
        setattr(gpdb_models, class_name, _request_class(class_name))

    class _Client:
        def __init__(self, config):
            self.config = config

    gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client")
    gpdb_client.Client = _Client
    gpdb_package.models = gpdb_models
    tea_openapi = types.ModuleType("alibabacloud_tea_openapi")
    tea_openapi.__path__ = []
    tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models")

    class OpenApiConfig:
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    tea_openapi_models.Config = OpenApiConfig
    tea_openapi.models = tea_openapi_models
    tea_package = types.ModuleType("Tea")
    tea_package.__path__ = []
    tea_exceptions = types.ModuleType("Tea.exceptions")

    class TeaError(Exception):
        def __init__(self, status_code=None, **kwargs):
            super().__init__("TeaException")
            # Accept the camelCase keyword the SDK uses and mirror it on both attribute spellings.
            status_code = kwargs.get("statusCode", status_code)
            self.statusCode = status_code
            self.status_code = status_code

    tea_exceptions.TeaException = TeaError
    tea_package.exceptions = tea_exceptions
    monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package)
    monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models)
    monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client)
    monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi)
    monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models)
    monkeypatch.setitem(sys.modules, "Tea", tea_package)
    monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions)
    return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig)
def _config() -> AnalyticdbVectorOpenAPIConfig:
    """Return a fully populated, valid OpenAPI config used as the baseline in these tests."""
    return AnalyticdbVectorOpenAPIConfig(
        access_key_id="ak",
        access_key_secret="sk",
        region_id="cn-hangzhou",
        instance_id="instance-1",
        account="account",
        account_password="password",
        namespace="dify",
        namespace_password="ns-password",
    )
@pytest.mark.parametrize(
    ("field", "value", "error_message"),
    [
        ("access_key_id", "", "ANALYTICDB_KEY_ID"),
        ("access_key_secret", "", "ANALYTICDB_KEY_SECRET"),
        ("region_id", "", "ANALYTICDB_REGION_ID"),
        ("instance_id", "", "ANALYTICDB_INSTANCE_ID"),
        ("account", "", "ANALYTICDB_ACCOUNT"),
        ("account_password", "", "ANALYTICDB_PASSWORD"),
        ("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"),
    ],
)
def test_openapi_config_validation(field, value, error_message):
    """Blanking any required OpenAPI config field raises a ValueError naming its env setting."""
    values = _config().model_dump()
    values[field] = value
    with pytest.raises(ValueError, match=error_message):
        AnalyticdbVectorOpenAPIConfig.model_validate(values)
def test_openapi_config_to_client_params():
    """to_analyticdb_client_params exposes the credentials, region and read timeout."""
    config = _config()
    params = config.to_analyticdb_client_params()
    assert params["access_key_id"] == "ak"
    assert params["access_key_secret"] == "sk"
    assert params["region_id"] == "cn-hangzhou"
    assert params["read_timeout"] == 60000
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
    """__init__ lowercases the collection name, builds the SDK client and invokes _initialize."""
    stubs = _install_openapi_stubs(monkeypatch)
    initialize_mock = MagicMock()
    # Patch _initialize so the constructor does not touch redis or the (stubbed) API.
    monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock)
    vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config())
    assert vector._collection_name == "collection_1"
    assert isinstance(vector._client_config, stubs.OpenApiConfig)
    assert vector._client_config.user_agent == "dify"
    assert vector._client_config.access_key_id == "ak"
    assert vector._client.config is vector._client_config
    initialize_mock.assert_called_once_with()
def test_initialize_skips_when_cached(monkeypatch):
    """_initialize is a no-op when the redis init marker already exists."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
    # A truthy cached value signals that initialization already happened.
    monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._initialize_vector_database = MagicMock()
    vector._create_namespace_if_not_exists = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_not_called()
    vector._create_namespace_if_not_exists.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
    """_initialize performs the full setup and records the redis marker on a cache miss."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._initialize_vector_database = MagicMock()
    vector._create_namespace_if_not_exists = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_called_once()
    vector._create_namespace_if_not_exists.assert_called_once()
    openapi_module.redis_client.set.assert_called_once()
def test_initialize_vector_database_calls_openapi_client(monkeypatch):
    """_initialize_vector_database issues an init request carrying the instance credentials."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._initialize_vector_database()
    request = vector._client.init_vector_database.call_args.args[0]
    assert request.dbinstance_id == "instance-1"
    assert request.region_id == "cn-hangzhou"
    assert request.manager_account == "account"
    assert request.manager_account_password == "password"
def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
    """A 404 from describe_namespace triggers namespace creation."""
    stubs = _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404)
    vector._create_namespace_if_not_exists()
    vector._client.create_namespace.assert_called_once()
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
    """Non-404 API errors surface as a ValueError instead of creating the namespace."""
    stubs = _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500)
    with pytest.raises(ValueError, match="failed to create namespace"):
        vector._create_namespace_if_not_exists()
def test_create_namespace_noop_when_namespace_exists(monkeypatch):
    """When describe_namespace succeeds, no creation call is made."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector.config = _config()
    vector._client = MagicMock()
    vector._create_namespace_if_not_exists()
    vector._client.describe_namespace.assert_called_once()
    vector._client.create_namespace.assert_not_called()
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
    """A 404 from describe_collection leads to collection creation plus a redis cache marker."""
    stubs = _install_openapi_stubs(monkeypatch)
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404)
    vector._create_collection_if_not_exists(embedding_dimension=1024)
    vector._client.create_collection.assert_called_once()
    openapi_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
    """A cached redis marker short-circuits both describe and create calls."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._create_collection_if_not_exists(embedding_dimension=1024)
    vector._client.describe_collection.assert_not_called()
    vector._client.create_collection.assert_not_called()
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
    """Non-404 errors from describe_collection surface as a ValueError naming the collection."""
    stubs = _install_openapi_stubs(monkeypatch)
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock())
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500)
    with pytest.raises(ValueError, match="failed to create collection collection_1"):
        vector._create_collection_if_not_exists(embedding_dimension=512)
def test_openapi_add_delete_and_search_methods(monkeypatch):
    """End-to-end check of add_texts, text_exists, both delete filters and both search paths."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    documents = [
        Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}),
        SimpleNamespace(page_content="doc 2", metadata=None),
    ]
    embeddings = [[0.1, 0.2], [0.2, 0.3]]
    vector.add_texts(documents, embeddings)
    upsert_request = vector._client.upsert_collection_data.call_args.args[0]
    assert upsert_request.collection == "collection_1"
    # Only the metadata-bearing document is upserted.
    assert len(upsert_request.rows) == 1
    vector._client.query_collection_data.return_value = SimpleNamespace(
        body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()]))
    )
    assert vector.text_exists("d1") is True
    vector.delete_by_ids(["d1", "d2"])
    request = vector._client.delete_collection_data.call_args.args[0]
    assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')"
    vector.delete_by_metadata_field("document_id", "doc-1")
    request = vector._client.delete_collection_data.call_args.args[0]
    assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'"
    match_high = SimpleNamespace(
        score=0.9,
        metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"},
        values=SimpleNamespace(value=[1.0, 2.0]),
    )
    match_low = SimpleNamespace(
        score=0.1,
        metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"},
        values=SimpleNamespace(value=[3.0, 4.0]),
    )
    vector._client.query_collection_data.return_value = SimpleNamespace(
        body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high]))
    )
    # Only the match above the 0.5 threshold survives, and its score is attached to metadata.
    docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    assert len(docs_by_vector) == 1
    assert docs_by_vector[0].page_content == "high"
    assert docs_by_vector[0].metadata["score"] == 0.9
    # Full-text search applies its threshold the same way (0.1 < 0.2 is filtered out).
    docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2)
    assert len(docs_by_text) == 1
    assert docs_by_text[0].page_content == "high"
def test_text_exists_returns_false_when_matches_empty(monkeypatch):
    """text_exists reports False when the query returns no matches."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.query_collection_data.return_value = SimpleNamespace(
        body=SimpleNamespace(matches=SimpleNamespace(match=[]))
    )
    assert vector.text_exists("missing-id") is False
def test_openapi_delete_success(monkeypatch):
    """delete() issues a single delete_collection call to the SDK client."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector.delete()
    vector._client.delete_collection.assert_called_once()
def test_openapi_delete_propagates_errors(monkeypatch):
    """delete() does not swallow client errors — they propagate to the caller."""
    _install_openapi_stubs(monkeypatch)
    vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
    vector._collection_name = "collection_1"
    vector.config = _config()
    vector._client = MagicMock()
    vector._client.delete_collection.side_effect = RuntimeError("boom")
    with pytest.raises(RuntimeError, match="boom"):
        vector.delete()

View File

@ -0,0 +1,427 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
import psycopg2.errors
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
AnalyticdbVectorBySql,
AnalyticdbVectorBySqlConfig,
)
from core.rag.models.document import Document
def _config_values() -> dict:
return {
"host": "localhost",
"port": 5432,
"account": "account",
"account_password": "password",
"min_connection": 1,
"max_connection": 2,
"namespace": "dify",
}
@pytest.mark.parametrize(
    ("field", "value", "error_message"),
    [
        ("host", "", "ANALYTICDB_HOST"),
        ("port", 0, "ANALYTICDB_PORT"),
        ("account", "", "ANALYTICDB_ACCOUNT"),
        ("account_password", "", "ANALYTICDB_PASSWORD"),
        ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"),
        ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"),
    ],
)
def test_sql_config_required_fields(field, value, error_message):
    """Blanking any required SQL config field raises a ValueError naming its env setting."""
    values = _config_values()
    values[field] = value
    with pytest.raises(ValueError, match=error_message):
        AnalyticdbVectorBySqlConfig.model_validate(values)
def test_sql_config_rejects_min_connection_greater_than_max_connection():
    """min_connection must not exceed max_connection."""
    values = _config_values()
    values["min_connection"] = 10
    values["max_connection"] = 2
    with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"):
        AnalyticdbVectorBySqlConfig.model_validate(values)
def test_initialize_skips_when_cache_exists(monkeypatch):
    """SQL _initialize is a no-op when the redis init marker already exists."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._initialize_vector_database = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
    """SQL _initialize performs setup and records the redis marker on a cache miss."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._initialize_vector_database = MagicMock()
    vector._initialize()
    vector._initialize_vector_database.assert_called_once()
    sql_module.redis_client.set.assert_called_once()
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch):
    """_create_connection_pool delegates to psycopg2's SimpleConnectionPool."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"
    pool_instance = MagicMock()
    monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool_instance))
    pool = vector._create_connection_pool()
    assert pool is pool_instance
    sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once()
def test_get_cursor_context_manager_handles_connection_lifecycle():
    """_get_cursor yields the cursor, then closes it, commits, and returns the connection to the pool."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    cursor = MagicMock()
    connection = MagicMock()
    connection.cursor.return_value = cursor
    pool = MagicMock()
    pool.getconn.return_value = connection
    vector.pool = pool
    with vector._get_cursor() as cur:
        assert cur is cursor
    cursor.close.assert_called_once()
    connection.commit.assert_called_once()
    pool.putconn.assert_called_once_with(connection)
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch):
    """add_texts batches row inserts via execute_batch and skips documents without metadata."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    # Deterministic id prefix and a mocked batch executor to capture the rows.
    monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id")
    monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock())
    docs = [
        Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}),
        SimpleNamespace(page_content="doc 2", metadata=None),
    ]
    vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]])
    execute_args = sql_module.psycopg2.extras.execute_batch.call_args.args
    assert execute_args[0] is cursor
    # Only the metadata-bearing document yields an insert row.
    assert len(execute_args[2]) == 1
def test_text_exists_returns_true_and_false_based_on_query_result():
    """text_exists mirrors whether the SELECT returned a row."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    cursor.fetchone.return_value = ("row",)
    assert vector.text_exists("d1") is True
    cursor.fetchone.return_value = None
    assert vector.text_exists("d1") is False
def test_delete_by_ids_handles_empty_input_and_missing_table_error():
    """delete_by_ids is a no-op for [] and tolerates an UndefinedTable error."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    # Empty id list: no SQL should be issued at all.
    vector.delete_by_ids([])
    cursor.execute.assert_not_called()
    # A missing table must be swallowed, not raised.
    cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist")
    vector.delete_by_ids(["d1"])
def test_delete_by_metadata_field_handles_missing_table_error():
    """delete_by_metadata_field tolerates an UndefinedTable error from the database."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist")
    vector.delete_by_metadata_field("document_id", "doc-1")
@pytest.mark.parametrize("invalid_top_k", [0, "x", -1])
def test_search_by_vector_validates_top_k(invalid_top_k):
    """search_by_vector rejects non-positive or non-integer top_k values."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.table_name = "dify.collection"
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k)
def test_search_by_vector_returns_documents_above_threshold():
    """Only rows at or above the score threshold become documents."""
    adb = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    adb.table_name = "dify.collection"
    rows = [
        ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}),
        ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}),
    ]
    fake_cursor = MagicMock()
    fake_cursor.__iter__.return_value = iter(rows)

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    adb._get_cursor = _fake_cursor
    results = adb.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    # The 0.3-score row is filtered out; the survivor carries its score.
    assert len(results) == 1
    assert results[0].page_content == "content 1"
    assert results[0].metadata["score"] == 0.8
@pytest.mark.parametrize("invalid_top_k", [0, "x", -1])
def test_search_by_full_text_validates_top_k(invalid_top_k):
    """Full-text search applies the same positive-integer top_k rule."""
    adb = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    adb.table_name = "dify.collection"
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        adb.search_by_full_text("query", top_k=invalid_top_k)
def test_search_by_full_text_returns_documents():
    """Full-text rows are mapped to documents with their score attached."""
    adb = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    adb.table_name = "dify.collection"
    rows = [
        ("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9),
    ]
    fake_cursor = MagicMock()
    fake_cursor.__iter__.return_value = iter(rows)

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    adb._get_cursor = _fake_cursor
    results = adb.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"])
    assert len(results) == 1
    assert results[0].metadata["score"] == 0.9
    assert results[0].page_content == "content 1"
def test_delete_drops_table():
    """delete() issues exactly one statement (the table drop)."""
    adb = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    adb.table_name = "dify.collection"
    fake_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    adb._get_cursor = _fake_cursor
    adb.delete()
    fake_cursor.execute.assert_called_once()
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch):
    """__init__ lowercases the collection name and builds a connection pool."""
    config = AnalyticdbVectorBySqlConfig(**_config_values())
    pool = MagicMock()
    # Neutralize real initialization; capture the pool the constructor creates.
    monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock())
    monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=pool))
    adb = AnalyticdbVectorBySql("My_Collection", config)
    assert adb._collection_name == "my_collection"
    assert adb.table_name == "dify.my_collection"
    assert adb.databaseName == "knowledgebase"
    assert adb.pool is pool
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch):
    """_initialize_vector_database tolerates an already-existing database and
    an already-created zh_cn text search configuration, while still running
    the helper-function and schema DDL."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"
    # Bootstrap connection: CREATE DATABASE fails because it already exists.
    bootstrap_cursor = MagicMock()
    bootstrap_connection = MagicMock()
    bootstrap_connection.cursor.return_value = bootstrap_cursor
    bootstrap_cursor.execute.side_effect = RuntimeError("database already exists")
    monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection))
    # Worker connection from the pool: only the zh_cn search-config DDL fails.
    worker_cursor = MagicMock()
    worker_connection = MagicMock()
    worker_cursor.connection = worker_connection
    def _execute(sql, *args, **kwargs):
        if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql:
            raise RuntimeError("already exists")
    worker_cursor.execute.side_effect = _execute
    pooled_connection = MagicMock()
    pooled_connection.cursor.return_value = worker_cursor
    pool = MagicMock()
    pool.getconn.return_value = pooled_connection
    vector._create_connection_pool = MagicMock(return_value=pool)
    vector._initialize_vector_database()
    # Bootstrap resources are released and the pool is created exactly once.
    bootstrap_cursor.close.assert_called_once()
    bootstrap_connection.close.assert_called_once()
    vector._create_connection_pool.assert_called_once()
    # Both the tsquery helper function and the schema DDL ran on the worker.
    assert any(
        "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0]
        for call in worker_cursor.execute.call_args_list
    )
    assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list)
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch):
    """When the zhparser extension cannot be created, the worker transaction
    is rolled back and a RuntimeError is raised."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector.databaseName = "knowledgebase"
    bootstrap_cursor = MagicMock()
    bootstrap_connection = MagicMock()
    bootstrap_connection.cursor.return_value = bootstrap_cursor
    monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection))
    # Every worker statement fails, simulating an unavailable extension.
    worker_cursor = MagicMock()
    worker_connection = MagicMock()
    worker_cursor.connection = worker_connection
    worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable")
    pooled_connection = MagicMock()
    pooled_connection.cursor.return_value = worker_cursor
    pool = MagicMock()
    pool.getconn.return_value = pooled_connection
    vector._create_connection_pool = MagicMock(return_value=pool)
    with pytest.raises(RuntimeError, match="Failed to create zhparser extension"):
        vector._initialize_vector_database()
    # The failed transaction must be rolled back on the worker connection.
    worker_connection.rollback.assert_called_once()
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch):
    """On a cache miss, the collection table plus embedding index are created
    under the redis lock and the cache key is set afterwards."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._collection_name = "collection"
    vector.table_name = "dify.collection"
    # Lock context manager plus an empty cache force the creation path.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())
    cursor = MagicMock()

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    vector._create_collection_if_not_exists(embedding_dimension=3)
    # Both the table DDL and the embedding index DDL must have been issued.
    assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list)
    assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list)
    sql_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch):
    """Errors other than 'already exists' during table creation propagate."""
    vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
    vector._collection_name = "collection"
    vector.table_name = "dify.collection"
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(sql_module.redis_client, "set", MagicMock())
    cursor = MagicMock()
    # A permission failure is not an "exists" condition and must bubble up.
    cursor.execute.side_effect = RuntimeError("permission denied")

    @contextmanager
    def _cursor_context():
        yield cursor

    vector._get_cursor = _cursor_context
    with pytest.raises(RuntimeError, match="permission denied"):
        vector._create_collection_if_not_exists(embedding_dimension=3)
def test_delete_methods_raise_when_error_is_not_missing_table():
    """Errors other than a missing table propagate from both delete helpers."""
    adb = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
    adb.table_name = "dify.collection"
    fake_cursor = MagicMock()

    @contextmanager
    def _fake_cursor():
        yield fake_cursor

    adb._get_cursor = _fake_cursor

    fake_cursor.execute.side_effect = RuntimeError("unexpected delete failure")
    with pytest.raises(RuntimeError, match="unexpected delete failure"):
        adb.delete_by_ids(["doc-1"])

    fake_cursor.execute.side_effect = RuntimeError("unexpected metadata failure")
    with pytest.raises(RuntimeError, match="unexpected metadata failure"):
        adb.delete_by_metadata_field("document_id", "doc-1")

View File

@ -0,0 +1,542 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_pymochow_modules():
    """Create a stand-in ``pymochow`` package tree for import isolation.

    Returns a dict mapping dotted module names to module objects, ready to be
    installed into ``sys.modules`` before importing the Baidu vector module.
    """
    # Package skeleton: __path__ = [] marks packages so dotted submodule
    # imports resolve against these fakes.
    pymochow = types.ModuleType("pymochow")
    pymochow.__path__ = []
    pymochow_auth = types.ModuleType("pymochow.auth")
    pymochow_auth.__path__ = []
    pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials")
    pymochow_configuration = types.ModuleType("pymochow.configuration")
    pymochow_exception = types.ModuleType("pymochow.exception")
    pymochow_model = types.ModuleType("pymochow.model")
    pymochow_model.__path__ = []
    pymochow_model_database = types.ModuleType("pymochow.model.database")
    pymochow_model_enum = types.ModuleType("pymochow.model.enum")
    pymochow_model_schema = types.ModuleType("pymochow.model.schema")
    pymochow_model_table = types.ModuleType("pymochow.model.table")

    class _SimpleObject:
        # Generic stand-in that records positional args and mirrors kwargs
        # as attributes, so constructed objects are inspectable.
        def __init__(self, *args, **kwargs):
            self.args = args
            for key, value in kwargs.items():
                setattr(self, key, value)

    class ServerError(Exception):
        # Mirrors pymochow's ServerError carrying a numeric error code.
        def __init__(self, code):
            super().__init__(f"server error {code}")
            self.code = code

    class ServerErrCode:
        TABLE_NOT_EXIST = 1001
        DB_ALREADY_EXIST = 1002

    class IndexType:
        # __members__ emulates enum-style membership lookups.
        __members__ = {"HNSW": "HNSW"}

    class MetricType:
        __members__ = {"IP": "IP"}

    class IndexState:
        NORMAL = "NORMAL"

    class TableState:
        NORMAL = "NORMAL"

    class InvertedIndexAnalyzer:
        DEFAULT_ANALYZER = "DEFAULT_ANALYZER"

    class InvertedIndexParseMode:
        COARSE_MODE = "COARSE_MODE"

    class InvertedIndexFieldAttribute:
        ANALYZED = "ANALYZED"

    class FieldType:
        STRING = "STRING"
        TEXT = "TEXT"
        JSON = "JSON"
        FLOAT_VECTOR = "FLOAT_VECTOR"

    # Attach the fakes at the names the real SDK exposes.
    pymochow.MochowClient = _SimpleObject
    pymochow_credentials.BceCredentials = _SimpleObject
    pymochow_configuration.Configuration = _SimpleObject
    pymochow_exception.ServerError = ServerError
    pymochow_model_database.Database = _SimpleObject
    pymochow_model_enum.FieldType = FieldType
    pymochow_model_enum.IndexState = IndexState
    pymochow_model_enum.IndexType = IndexType
    pymochow_model_enum.MetricType = MetricType
    pymochow_model_enum.ServerErrCode = ServerErrCode
    pymochow_model_enum.TableState = TableState
    for cls_name in [
        "AutoBuildRowCountIncrement",
        "Field",
        "FilteringIndex",
        "HNSWParams",
        "InvertedIndex",
        "InvertedIndexParams",
        "Schema",
        "VectorIndex",
    ]:
        setattr(pymochow_model_schema, cls_name, _SimpleObject)
    pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer
    pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute
    pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode
    for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]:
        setattr(pymochow_model_table, cls_name, _SimpleObject)
    # Wire parent->child attributes so attribute access like
    # pymochow.model.schema works in addition to sys.modules lookups.
    pymochow.auth = pymochow_auth
    pymochow.model = pymochow_model
    pymochow_auth.bce_credentials = pymochow_credentials
    pymochow_model.database = pymochow_model_database
    pymochow_model.enum = pymochow_model_enum
    pymochow_model.schema = pymochow_model_schema
    pymochow_model.table = pymochow_model_table
    modules = {
        "pymochow": pymochow,
        "pymochow.auth": pymochow_auth,
        "pymochow.auth.bce_credentials": pymochow_credentials,
        "pymochow.configuration": pymochow_configuration,
        "pymochow.exception": pymochow_exception,
        "pymochow.model": pymochow_model,
        "pymochow.model.database": pymochow_model_database,
        "pymochow.model.enum": pymochow_model_enum,
        "pymochow.model.schema": pymochow_model_schema,
        "pymochow.model.table": pymochow_model_table,
    }
    return modules
@pytest.fixture
def baidu_module(monkeypatch):
    """Import the Baidu vector module against the fake pymochow SDK.

    The fake modules are installed into ``sys.modules`` first, then the
    module under test is reloaded so it binds to the fakes.
    """
    for name, module in _build_fake_pymochow_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.baidu.baidu_vector as module
    return importlib.reload(module)
def test_baidu_config_validation(baidu_module):
    """A complete config validates; each blank required field names its env var."""
    base_values = {
        "endpoint": "https://example.com",
        "account": "account",
        "api_key": "key",
        "database": "database",
    }
    config = baidu_module.BaiduConfig.model_validate(base_values)
    assert config.endpoint == "https://example.com"
    required = {
        "endpoint": "BAIDU_VECTOR_DB_ENDPOINT",
        "account": "BAIDU_VECTOR_DB_ACCOUNT",
        "api_key": "BAIDU_VECTOR_DB_API_KEY",
        "database": "BAIDU_VECTOR_DB_DATABASE",
    }
    for field, expected_message in required.items():
        broken = {**base_values, field: ""}
        with pytest.raises(ValueError, match=expected_message):
            baidu_module.BaiduConfig.model_validate(broken)
def test_get_search_result_handles_metadata_and_threshold(baidu_module):
    """_get_search_res accepts str/dict/other metadata and filters by score."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    rows = [
        {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9},
        {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4},
        {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95},
    ]
    docs = vector._get_search_res(SimpleNamespace(rows=rows), score_threshold=0.8)
    # The 0.4-score row is dropped; survivors keep order and carry scores.
    assert [doc.page_content for doc in docs] == ["doc1", "doc3"]
    assert docs[0].metadata["score"] == 0.9
def test_delete_by_ids_and_delete_by_metadata_field(baidu_module):
    """delete_by_ids skips empty input; metadata deletes escape quotes."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    vector._collection_name = "collection_1"
    # Empty id list is a no-op.
    vector.delete_by_ids([])
    table.delete.assert_not_called()
    vector.delete_by_ids(["id1", "id2"])
    table.delete.assert_called_once()
    table.delete.reset_mock()
    # Double quotes inside the value must be escaped in the delete filter.
    vector.delete_by_metadata_field("source", 'abc"def')
    delete_filter = table.delete.call_args.kwargs["filter"]
    assert delete_filter == 'metadata["source"] = "abc\\"def"'
def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module):
    """A TABLE_NOT_EXIST drop error is swallowed; any other code bubbles up."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()

    vector._db.drop_table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST)
    vector.delete()  # tolerated: the table is already gone

    vector._db.drop_table.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector.delete()
def test_init_database_uses_existing_or_creates_when_missing(baidu_module):
    """_init_database reuses an existing database, creates one when absent,
    and tolerates a concurrent DB_ALREADY_EXIST creation error."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._client = MagicMock()
    vector._client_config = SimpleNamespace(database="my_db")
    # Database already listed: it is opened directly.
    vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")]
    vector._client.database.return_value = "existing_db"
    assert vector._init_database() == "existing_db"
    # Database missing: it gets created, then opened.
    vector._client.list_databases.return_value = []
    vector._client.database.return_value = "created_db"
    vector._client.create_database.side_effect = None
    assert vector._init_database() == "created_db"
    # A racing creator (DB_ALREADY_EXIST) is tolerated.
    vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST)
    assert vector._init_database() == "created_db"
def test_table_existed_checks_table_access(baidu_module):
    """_table_existed maps table access outcomes to True/False/raise."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()

    # Table resolves: exists.
    vector._db.table.return_value = MagicMock()
    assert vector._table_existed() is True

    # TABLE_NOT_EXIST: reported as absent.
    vector._db.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST)
    assert vector._table_existed() is False

    # Any other server error propagates.
    vector._db.table.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector._table_existed()
def test_search_methods_delegate_to_database_table(baidu_module):
    """Both search entry points funnel their raw results into _get_search_res."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    expected_docs = [Document(page_content="doc", metadata={"doc_id": "1"})]
    vector._get_search_res = MagicMock(return_value=expected_docs)
    table = MagicMock()
    vector._db.table.return_value = table
    table.search.return_value = "vector_result"
    table.bm25_search.return_value = "bm25_result"
    by_vector = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2)
    by_text = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2)
    assert by_vector == expected_docs
    assert by_text == expected_docs
    assert vector._get_search_res.call_count == 2
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch):
    """When the dataset has no index struct, the factory derives a lowercase
    collection name and persists a new index struct on the dataset."""
    factory = baidu_module.BaiduVectorFactory()
    dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full set of Baidu config values the factory reads.
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300)
    with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    # The generated name is lowercased before being handed to BaiduVector.
    assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection"
    assert dataset.index_struct is not None
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch):
    """__init__ wires client and database, get_type/to_index_struct report
    identity, and create() delegates to _create_table then add_texts."""
    init_client = MagicMock(return_value="client")
    init_database = MagicMock(return_value="database")
    monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client)
    monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database)
    config = baidu_module.BaiduConfig(
        endpoint="https://example.com",
        account="account",
        api_key="key",
        database="db",
    )
    vector = baidu_module.BaiduVector(collection_name="my_collection", config=config)
    assert vector.get_type() == baidu_module.VectorType.BAIDU
    assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection"
    assert vector._client == "client"
    assert vector._db == "database"
    # create() sizes the table from the embedding dimension, then inserts.
    vector._create_table = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="p1", metadata={"doc_id": "d1"})]
    vector.create(docs, [[0.1, 0.2]])
    vector._create_table.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_batches_rows(baidu_module):
    """A batch smaller than the batch size is upserted in a single call."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    docs = [
        Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"})
        for idx in (1, 2)
    ]
    vector.add_texts(docs, [[0.1, 0.2], [0.3, 0.4]])
    assert table.upsert.call_count == 1
    assert len(table.upsert.call_args.kwargs["rows"]) == 2
def test_add_texts_batches_more_than_batch_size(baidu_module):
    """1001 documents split into a full batch of 1000 plus a trailing one."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    total = 1001
    docs = [
        Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"})
        for idx in range(total)
    ]
    vector.add_texts(docs, [[0.1, 0.2] for _ in range(total)])
    batch_sizes = [len(call.kwargs["rows"]) for call in table.upsert.call_args_list]
    assert batch_sizes == [1000, 1]
def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module):
    """text_exists is True only for a query response carrying code == 0."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    table = MagicMock()
    vector._db = MagicMock()
    vector._db.table.return_value = table
    for response, expected in [
        (SimpleNamespace(code=0), True),
        (SimpleNamespace(code=1), False),
        (None, False),
    ]:
        table.query.return_value = response
        assert vector.text_exists("id-1") is expected
def test_get_search_result_handles_invalid_metadata_json(baidu_module):
    """Unparseable metadata JSON yields a document without parsed fields."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    rows = [{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}]
    docs = vector._get_search_res(SimpleNamespace(rows=rows), score_threshold=0.1)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.7
    # The broken JSON contributes no metadata keys.
    assert "document_id" not in docs[0].metadata
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch):
    """_init_client wires credentials -> configuration -> MochowClient."""
    fake_credentials = MagicMock(return_value="credentials")
    fake_configuration = MagicMock(return_value="configuration")
    fake_client_cls = MagicMock(return_value="client")
    monkeypatch.setattr(baidu_module, "BceCredentials", fake_credentials)
    monkeypatch.setattr(baidu_module, "Configuration", fake_configuration)
    monkeypatch.setattr(baidu_module, "MochowClient", fake_client_cls)
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint")
    assert vector._init_client(config) == "client"
    # Each construction step receives exactly the previous step's product.
    fake_credentials.assert_called_once_with("account", "key")
    fake_configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint")
    fake_client_cls.assert_called_once_with("configuration")
def test_init_database_raises_for_unknown_create_database_error(baidu_module):
    """create_database failures other than DB_ALREADY_EXIST propagate."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._client = MagicMock()
    vector._client_config = SimpleNamespace(database="my_db")
    # No database listed, so _init_database attempts creation — which fails.
    vector._client.list_databases.return_value = []
    vector._client.create_database.side_effect = baidu_module.ServerError(9999)
    with pytest.raises(baidu_module.ServerError):
        vector._init_database()
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch):
    """_create_table short-circuits on a cache hit or an existing table, and
    otherwise creates the table, rebuilds the index and waits for readiness."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._client_config = SimpleNamespace(
        index_type="HNSW",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    vector._db = MagicMock()
    table = MagicMock()
    table.state = baidu_module.TableState.NORMAL
    vector._db.describe_table.return_value = table
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None)
    monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock())
    # Cached table skips all work.
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_table(3)
    vector._db.create_table.assert_not_called()
    # Existing table also skips creation.
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    vector._table_existed.return_value = True
    vector._create_table(3)
    vector._db.create_table.assert_not_called()
    # Create table when cache is empty and table does not exist.
    vector._table_existed.return_value = False
    vector._create_table(3)
    vector._db.create_table.assert_called_once()
    baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600)
    table.rebuild_index.assert_called_once_with(vector.vector_index)
    vector._wait_for_index_ready.assert_called_once_with(table, 3600)
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch):
    """Unsupported index_type or metric_type values raise ValueError."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._db = MagicMock()
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()
    vector._client_config = SimpleNamespace(
        index_type="INVALID",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    # First run with an invalid index type, then with an invalid metric type.
    with pytest.raises(ValueError, match="unsupported index_type"):
        vector._create_table(3)
    vector._client_config.index_type = "HNSW"
    vector._client_config.metric_type = "INVALID"
    with pytest.raises(ValueError, match="unsupported metric_type"):
        vector._create_table(3)
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch):
    """A table stuck outside the NORMAL state times out during creation."""
    vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
    vector._collection_name = "collection_1"
    vector._client_config = SimpleNamespace(
        index_type="HNSW",
        metric_type="IP",
        inverted_index_analyzer="DEFAULT_ANALYZER",
        inverted_index_parser_mode="COARSE_MODE",
        auto_build_row_count_increment=500,
        auto_build_row_count_increment_ratio=0.05,
        rebuild_index_timeout_in_seconds=300,
        replicas=1,
        shard=1,
    )
    vector._db = MagicMock()
    # describe_table never reports NORMAL, so the wait loop cannot finish.
    vector._db.describe_table.return_value = SimpleNamespace(state="CREATING")
    vector._table_existed = MagicMock(return_value=False)
    vector.delete = MagicMock()
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None)
    # time.time jumps past the 300s budget on its second call.
    monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301]))
    with pytest.raises(TimeoutError, match="Table creation timeout"):
        vector._create_table(3)
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch):
    """An existing index struct's class_prefix is reused (lowercased) instead
    of generating a fresh collection name."""
    factory = baidu_module.BaiduVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    # Provide the full set of Baidu config values the factory reads.
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE")
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05)
    monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300)
    with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"

View File

@ -0,0 +1,199 @@
import importlib
import sys
import types
from collections import UserDict
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_chroma_modules():
    """Create a fake ``chromadb`` module exposing the names the Chroma vector
    wrapper imports: Settings, QueryResult, HttpClient and the tenant/database
    defaults. Returns the module object."""
    chroma = types.ModuleType("chromadb")
    chroma.DEFAULT_TENANT = "default_tenant"
    chroma.DEFAULT_DATABASE = "default_database"

    class Settings:
        # Records arbitrary keyword settings as attributes.
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    class QueryResult(UserDict):
        pass

    class _Collection:
        # Collection stub whose operations are all inspectable mocks.
        def __init__(self):
            self.upsert = MagicMock()
            self.delete = MagicMock()
            self.query = MagicMock()
            self.get = MagicMock(return_value={})

    class _Client:
        # HttpClient stand-in; each client owns one shared collection stub.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.collection = _Collection()
            self.get_or_create_collection = MagicMock(return_value=self.collection)
            self.delete_collection = MagicMock()

    chroma.Settings = Settings
    chroma.QueryResult = QueryResult
    chroma.HttpClient = _Client
    return chroma
@pytest.fixture
def chroma_module(monkeypatch):
    """Import the Chroma vector module against the fake chromadb module."""
    fake_chroma = _build_fake_chroma_modules()
    monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
    import core.rag.datasource.vdb.chroma.chroma_vector as module
    # Reload so the module binds to the fake chromadb just installed.
    return importlib.reload(module)
def test_chroma_config_to_params_builds_expected_payload(chroma_module):
    """to_chroma_params carries connection and auth settings through."""
    config = chroma_module.ChromaConfig(
        host="localhost",
        port=8000,
        tenant="tenant-1",
        database="db-1",
        auth_provider="provider",
        auth_credentials="credentials",
    )
    params = config.to_chroma_params()
    for key, value in {"host": "localhost", "port": 8000, "tenant": "tenant-1", "database": "db-1"}.items():
        assert params[key] == value
    assert params["ssl"] is False
    settings = params["settings"]
    assert settings.chroma_client_auth_provider == "provider"
    assert settings.chroma_client_auth_credentials == "credentials"
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch):
    """On a cache miss, create_collection runs under the redis lock, creates
    the collection and records the creation in the cache."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(chroma_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(chroma_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(chroma_module.redis_client, "set", MagicMock())
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    vector.create_collection("collection_1")
    vector._client.get_or_create_collection.assert_called_once_with("collection_1")
    chroma_module.redis_client.set.assert_called_once()
def test_create_with_empty_texts_is_noop(chroma_module):
    """create([]) must not touch the Chroma client at all."""
    chroma_vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    chroma_vector.create([], [])
    chroma_vector._client.get_or_create_collection.assert_not_called()
def test_create_with_texts_creates_collection_and_upserts(chroma_module):
    """create() with documents ensures the collection and upserts the rows."""
    chroma_vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    documents = [Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})]
    chroma_vector.create(documents, [[0.1, 0.2]])
    chroma_vector._client.get_or_create_collection.assert_called()
    chroma_vector._client.collection.upsert.assert_called_once()
def test_delete_methods_and_text_exists(chroma_module):
    """Exercise delete_by_ids, delete_by_metadata_field, text_exists and
    delete in sequence against the collection stub."""
    vector = chroma_module.ChromaVector(
        collection_name="collection_1",
        config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"),
    )
    # Empty id list must not hit the collection.
    vector.delete_by_ids([])
    vector._client.collection.delete.assert_not_called()
    vector.delete_by_ids(["id-1"])
    vector._client.collection.delete.assert_called_with(ids=["id-1"])
    # Metadata deletes translate into a Chroma `$eq` where-filter.
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}})
    # text_exists keys off whether the returned ids list is non-empty.
    vector._client.collection.get.return_value = {"ids": ["id-1"]}
    assert vector.text_exists("id-1") is True
    vector._client.collection.get.return_value = {}
    assert vector.text_exists("id-2") is False
    # delete drops the whole collection by name.
    vector.delete()
    vector._client.delete_collection.assert_called_once_with("collection_1")
def test_search_by_vector_handles_empty_results(chroma_module):
    """An empty query response yields an empty document list."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    store = chroma_module.ChromaVector(collection_name="collection_1", config=config)
    empty_response = {"ids": [], "documents": [], "metadatas": [], "distances": []}
    store._client.collection.query.return_value = empty_response
    assert store.search_by_vector([0.1, 0.2], top_k=2) == []
def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module):
    """Only hits whose score clears the threshold come back, with scores attached."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    store = chroma_module.ChromaVector(collection_name="collection_1", config=config)
    # Distance 0.1 -> score 0.9 (kept); distance 0.8 -> score 0.2 (below 0.5 threshold).
    store._client.collection.query.return_value = {
        "ids": [["id-1", "id-2"]],
        "documents": [["doc high", "doc low"]],
        "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]],
        "distances": [[0.1, 0.8]],
    }
    results = store.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    assert len(results) == 1
    assert results[0].page_content == "doc high"
    assert results[0].metadata["score"] == 0.9
def test_search_by_full_text_returns_empty_list(chroma_module):
    """Full-text search is unsupported for Chroma and always returns []."""
    config = chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d")
    store = chroma_module.ChromaVector(collection_name="collection_1", config=config)
    assert store.search_by_full_text("query") == []
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch):
    """Factory reuses an existing class_prefix or falls back to a generated name."""
    factory = chroma_module.ChromaVectorFactory()
    # Dataset that already carries an index struct with a collection prefix.
    dataset_with_index = SimpleNamespace(
        id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(chroma_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost")
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None)
    monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None)
    with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # The asserted names show collection names are lower-cased by the factory.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one persisted during init.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,927 @@
import importlib
import queue
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_clickzetta_module():
clickzetta = types.ModuleType("clickzetta")
class _FakeCursor:
def __init__(self):
self.execute = MagicMock()
self.executemany = MagicMock()
self.fetchall = MagicMock(return_value=[])
self.fetchone = MagicMock(return_value=(0,))
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class _FakeConnection:
def __init__(self):
self.cursor_obj = _FakeCursor()
def cursor(self):
return self.cursor_obj
def close(self):
return None
def connect(**_kwargs):
return _FakeConnection()
clickzetta.connect = connect
return clickzetta
@pytest.fixture
def clickzetta_module(monkeypatch):
    """Reload the clickzetta vector module against the fake ``clickzetta`` driver."""
    # Inject the fake driver before (re)importing so module-level imports bind to it.
    monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
    import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module
    return importlib.reload(module)
def _config(module):
    """Build a fully-populated ClickzettaConfig from the reloaded module."""
    return module.ClickzettaConfig(
        username="username",
        password="password",
        instance="instance",
        service="service",
        workspace="workspace",
        vcluster="cluster",
        schema_name="dify",
    )
@pytest.mark.parametrize(
    ("field", "error_message"),
    [
        ("username", "CLICKZETTA_USERNAME"),
        ("password", "CLICKZETTA_PASSWORD"),
        ("instance", "CLICKZETTA_INSTANCE"),
        ("service", "CLICKZETTA_SERVICE"),
        ("workspace", "CLICKZETTA_WORKSPACE"),
        ("vcluster", "CLICKZETTA_VCLUSTER"),
        ("schema_name", "CLICKZETTA_SCHEMA"),
    ],
)
def test_clickzetta_config_validation(clickzetta_module, field, error_message):
    """Blanking any required field raises a ValueError naming its env variable."""
    values = _config(clickzetta_module).model_dump()
    values[field] = ""
    with pytest.raises(ValueError, match=error_message):
        clickzetta_module.ClickzettaConfig.model_validate(values)
def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module):
    """_parse_metadata copes with plain JSON, double-encoded JSON, and garbage."""
    store = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    single_encoded = store._parse_metadata('{"document_id":"doc-1"}', "row-1")
    assert single_encoded["doc_id"] == "row-1"
    assert single_encoded["document_id"] == "doc-1"
    double_encoded = store._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2")
    assert double_encoded["doc_id"] == "row-2"
    assert double_encoded["document_id"] == "doc-2"
    garbage = store._parse_metadata("not-json", "row-3")
    assert garbage["doc_id"] == "row-3"
    assert garbage["document_id"] == "row-3"
def test_safe_doc_id_and_vector_format_helpers(clickzetta_module):
    """Vector formatting joins floats; _safe_doc_id sanitizes and truncates ids."""
    store = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    assert store._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3"
    assert store._safe_doc_id("abc-123_DEF") == "abc-123_DEF"
    assert store._safe_doc_id("ab c;\n") == "abc"
    truncated = store._safe_doc_id("a" * 300)
    assert len(truncated) == 255
def test_table_exists_returns_false_for_not_found_and_other_exceptions(clickzetta_module):
    """_table_exists reports False for both 'not found' and unexpected errors."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    @contextmanager
    def _ctx_not_found():
        # Simulate Clickzetta's "table or view not found" error code.
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("CZLH-42000 table or view not found")
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx_not_found
    assert vector._table_exists() is False
    @contextmanager
    def _ctx_other_error():
        # Any other failure is also reported as "table does not exist".
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("permission denied")
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx_other_error
    assert vector._table_exists() is False
def test_text_exists_handles_missing_table_and_existing_rows(clickzetta_module):
    """text_exists is False without a table, True when a row count is returned."""
    store = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    store._config = _config(clickzetta_module)
    store._table_name = "table_1"
    store._table_exists = MagicMock(return_value=False)
    assert store.text_exists("doc-1") is False
    store._table_exists = MagicMock(return_value=True)

    @contextmanager
    def _connection_ctx():
        # Cursor reports a non-zero COUNT for the queried id.
        fake_connection = MagicMock()
        fake_cursor = MagicMock()
        fake_cursor.__enter__.return_value = fake_cursor
        fake_cursor.__exit__.return_value = None
        fake_cursor.fetchone.return_value = (1,)
        fake_connection.cursor.return_value = fake_cursor
        yield fake_connection

    store.get_connection_context = _connection_ctx
    assert store.text_exists("doc-1") is True
def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module):
    """Deletes short-circuit for empty ids or when the backing table is absent."""
    store = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    store._config = _config(clickzetta_module)
    store._table_name = "table_1"
    store._execute_write = MagicMock()
    store.delete_by_ids([])
    store._execute_write.assert_not_called()
    store._table_exists = MagicMock(return_value=False)
    store.delete_by_ids(["doc-1"])
    store._execute_write.assert_not_called()
    store.delete_by_metadata_field("document_id", "doc-1")
    store._execute_write.assert_not_called()
def test_search_short_circuit_behaviors(clickzetta_module):
    """Searches return [] for a missing table or disabled inverted index."""
    store = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    store._config = _config(clickzetta_module)
    store._table_name = "table_1"
    store._table_exists = MagicMock(return_value=False)
    assert store.search_by_vector([0.1, 0.2], top_k=2) == []
    store._config.enable_inverted_index = False
    assert store.search_by_full_text("query", top_k=2) == []
def test_search_by_like_returns_documents_with_default_score(clickzetta_module):
    """The LIKE fallback search returns documents scored with a fixed 0.5."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)
    vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"})
    @contextmanager
    def _ctx():
        # Connection whose cursor yields a single matching row.
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')]
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx
    docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "content"
    assert docs[0].metadata["score"] == 0.5
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
    """Factory builds a ClickzettaVector from dify_config with a lower-cased name."""
    factory = clickzetta_module.ClickzettaVectorFactory()
    dataset = SimpleNamespace(id="dataset-1")
    monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION")
    # Populate every config knob the factory reads.
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10)
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True)
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart")
    monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance")
    with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    # Generated "COLLECTION" is lower-cased before constructing the vector.
    assert vector_cls.call_args.kwargs["collection_name"] == "collection"
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
    """get_instance returns a singleton; config keys embed the identity fields."""
    clickzetta_module.ClickzettaConnectionPool._instance = None
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    first_pool = clickzetta_module.ClickzettaConnectionPool.get_instance()
    second_pool = clickzetta_module.ClickzettaConnectionPool.get_instance()
    config_key = first_pool._get_config_key(_config(clickzetta_module))
    assert first_pool is second_pool
    assert "username:instance:service:workspace:cluster:dify" in config_key
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch):
    """_create_connection retries after a failure and configures the connection."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    config = _config(clickzetta_module)
    connection = MagicMock()
    # Skip real backoff sleeps between retries.
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    # First connect attempt fails, second succeeds.
    monkeypatch.setattr(
        clickzetta_module.clickzetta, "connect", MagicMock(side_effect=[RuntimeError("boom"), connection])
    )
    pool._configure_connection = MagicMock()
    created = pool._create_connection(config)
    assert created is connection
    assert clickzetta_module.clickzetta.connect.call_count == 2
    pool._configure_connection.assert_called_once_with(connection)
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch):
    """A persistently failing connect() bubbles up once the retries are spent."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    monkeypatch.setattr(clickzetta_module.clickzetta, "connect", MagicMock(side_effect=RuntimeError("boom")))
    failing_config = _config(clickzetta_module)
    with pytest.raises(RuntimeError, match="boom"):
        pool._create_connection(failing_config)
def test_connection_pool_configure_and_validate_connection(clickzetta_module):
    """Configured connections run setup SQL and validate; broken ones fail validation."""
    # Use MonkeyPatch.context() so the patch is undone even when an assertion
    # fails; the previous manual undo() call was skipped on any test failure,
    # leaking the _start_cleanup_thread patch into later tests.
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
        pool = clickzetta_module.ClickzettaConnectionPool()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection = MagicMock()
        connection.cursor.return_value = cursor
        pool._configure_connection(connection)
        # Configuration issues at least two setup statements on the cursor.
        assert cursor.execute.call_count >= 2
        assert pool._is_connection_valid(connection) is True
        # A connection whose cursor() raises is reported invalid.
        bad_connection = MagicMock()
        bad_connection.cursor.side_effect = RuntimeError("bad connection")
        assert pool._is_connection_valid(bad_connection) is False
def test_connection_pool_configure_connection_swallows_errors(clickzetta_module):
    """_configure_connection must not raise when the connection cannot be configured."""
    # Use MonkeyPatch.context() so the patch is always undone; the previous
    # manual undo() call leaked the patch whenever the test body raised.
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
        pool = clickzetta_module.ClickzettaConnectionPool()
        connection = MagicMock()
        connection.cursor.side_effect = RuntimeError("cannot configure")
        # Succeeds (no exception) despite the failing cursor.
        pool._configure_connection(connection)
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch):
    """Walk pool get/reuse/expiry/return/cleanup/shutdown in sequence."""
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
    pool = clickzetta_module.ClickzettaConnectionPool()
    config = _config(clickzetta_module)
    key = pool._get_config_key(config)
    # Empty pool: a fresh connection is created.
    created_connection = MagicMock()
    pool._create_connection = MagicMock(return_value=created_connection)
    first = pool.get_connection(config)
    assert first is created_connection
    # A valid pooled connection is reused instead of creating a new one.
    reusable_connection = MagicMock()
    pool._pools[key] = [(reusable_connection, clickzetta_module.time.time())]
    pool._is_connection_valid = MagicMock(return_value=True)
    reused = pool.get_connection(config)
    assert reused is reusable_connection
    # An expired/invalid pooled connection is closed and replaced.
    expired_connection = MagicMock()
    pool._pools[key] = [(expired_connection, 0.0)]
    pool._is_connection_valid = MagicMock(return_value=False)
    monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0))
    pool.get_connection(config)
    expired_connection.close.assert_called_once()
    # Returning a valid connection places it back in the pool.
    random_connection = MagicMock()
    pool._is_connection_valid = MagicMock(return_value=True)
    pool.return_connection(config, random_connection)
    assert len(pool._pools[key]) == 1
    # Cleanup drops entries older than the timeout, keeps the fresh one.
    pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)]
    pool._connection_timeout = 10
    pool._cleanup_expired_connections()
    assert len(pool._pools[key]) == 1
    # Returning to an unknown pool key closes the connection instead.
    unknown_pool = MagicMock()
    pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), unknown_pool)
    unknown_pool.close.assert_called_once()
    pool.shutdown()
    assert pool._shutdown is True
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch):
    """The cleanup thread is started as a daemon and invokes the cleanup worker."""
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._shutdown = False
    # First cleanup iteration flips _shutdown so the worker loop exits.
    pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True))
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    class _Thread:
        # Synchronous Thread replacement: start() runs the target inline.
        def __init__(self, target, daemon):
            self._target = target
            self.daemon = daemon
            self.started = False
        def start(self):
            self.started = True
            self._target()
    monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread)
    pool._start_cleanup_thread()
    assert pool._cleanup_thread.started is True
    pool._cleanup_expired_connections.assert_called_once()
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch):
    """Constructor lower-cases the name; connection helpers delegate to the pool."""
    pool = MagicMock()
    pool.get_connection.return_value = "conn"
    monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool))
    # Avoid spawning the background write worker during construction.
    monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock())
    vector = clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module))
    assert vector._table_name == "my_collection"
    assert vector._get_connection() == "conn"
    vector._return_connection("conn")
    pool.return_connection.assert_called_with(vector._config, "conn")
    # The context manager hands out and then returns the pooled connection.
    with vector.get_connection_context() as conn:
        assert conn == "conn"
    assert pool.return_connection.call_count >= 2
    assert vector.get_type() == "clickzetta"
    assert vector._ensure_connection() == "conn"
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch):
    """Cover queue init idempotence, the worker loop, and _execute_write paths."""
    class _Thread:
        # Counts start() calls without spawning a real thread.
        def __init__(self, target, daemon):
            self.target = target
            self.daemon = daemon
            self.started = 0
        def start(self):
            self.started += 1
    monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread)
    # Reset class-level write state, then init twice: the thread starts once.
    clickzetta_module.ClickzettaVector._write_queue = None
    clickzetta_module.ClickzettaVector._write_thread = None
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._init_write_queue()
    clickzetta_module.ClickzettaVector._init_write_queue()
    assert clickzetta_module.ClickzettaVector._write_thread.started == 1
    # Feed the worker one succeeding task, one raising task, then a None
    # sentinel so the worker loop terminates.
    result_queue_ok = queue.Queue()
    result_queue_fail = queue.Queue()
    clickzetta_module.ClickzettaVector._write_queue = queue.Queue()
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, result_queue_ok))
    clickzetta_module.ClickzettaVector._write_queue.put(
        (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, result_queue_fail)
    )
    clickzetta_module.ClickzettaVector._write_queue.put(None)
    clickzetta_module.ClickzettaVector._write_worker()
    assert result_queue_ok.get() == (True, 2)
    failed = result_queue_fail.get()
    assert failed[0] is False
    assert isinstance(failed[1], RuntimeError)
    # _execute_write raises when the queue was never initialized.
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    clickzetta_module.ClickzettaVector._write_queue = None
    with pytest.raises(RuntimeError, match="Write queue not initialized"):
        vector._execute_write(lambda: None)
    class _ImmediateSuccessQueue:
        # Executes the task synchronously and reports success.
        def put(self, task):
            func, args, kwargs, result_q = task
            result_q.put((True, func(*args, **kwargs)))
    clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue()
    assert vector._execute_write(lambda x: x * 2, 3) == 6
    class _ImmediateFailQueue:
        # Reports a failure result; _execute_write must re-raise it.
        def put(self, task):
            _, _, _, result_q = task
            result_q.put((False, ValueError("write failed")))
    clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue()
    with pytest.raises(ValueError, match="write failed"):
        vector._execute_write(lambda: None)
def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module):
    """_table_exists is True on a clean query; create() schedules write + add_texts."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    @contextmanager
    def _ctx_exists():
        # Cursor executes without raising, so the table is considered present.
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx_exists
    assert vector._table_exists() is True
    vector._execute_write = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="content", metadata={"doc_id": "d1"})]
    vector.create(docs, [[0.1, 0.2]])
    vector._execute_write.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_table_and_indexes_paths(clickzetta_module):
    """Index creation is skipped, full, or vector-only depending on config/state."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._create_vector_index = MagicMock()
    vector._create_inverted_index = MagicMock()
    # Existing table: nothing is created.
    vector._table_exists = MagicMock(return_value=True)
    vector._create_table_and_indexes([[0.1, 0.2]])
    vector._create_vector_index.assert_not_called()
    vector._table_exists = MagicMock(return_value=False)
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx
    # Missing table with inverted index enabled: both indexes are created.
    vector._create_table_and_indexes([[0.1, 0.2, 0.3]])
    vector._create_vector_index.assert_called_once()
    vector._create_inverted_index.assert_called_once()
    # Inverted index disabled: only the vector index is created.
    vector._config.enable_inverted_index = False
    vector._create_vector_index.reset_mock()
    vector._create_inverted_index.reset_mock()
    vector._create_table_and_indexes([])
    vector._create_vector_index.assert_called_once()
    vector._create_inverted_index.assert_not_called()
def test_create_vector_index_branches(clickzetta_module):
    """Cover existing-index, show-failure, already-exists, and error branches."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()
    # Index already listed: only the SHOW statement runs.
    cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")]
    vector._create_vector_index(cursor)
    assert cursor.execute.call_count == 1
    cursor.reset_mock()
    # SHOW fails: creation proceeds anyway (second execute).
    cursor.execute.side_effect = [RuntimeError("show index failed"), None]
    vector._create_vector_index(cursor)
    assert cursor.execute.call_count == 2
    cursor.reset_mock()
    # CREATE raising "already exists" is tolerated.
    cursor.execute.side_effect = [None, RuntimeError("already exists")]
    cursor.fetchall.return_value = []
    vector._create_vector_index(cursor)
    cursor.reset_mock()
    # Any other CREATE failure propagates.
    cursor.execute.side_effect = [None, RuntimeError("unexpected")]
    cursor.fetchall.return_value = []
    with pytest.raises(RuntimeError, match="unexpected"):
        vector._create_vector_index(cursor)
def test_create_inverted_index_branches(clickzetta_module):
    """Cover existing-index, show-failure, conflict, and swallowed-error branches."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()
    # Index already listed: only the SHOW statement runs.
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    vector._create_inverted_index(cursor)
    assert cursor.execute.call_count == 1
    cursor.reset_mock()
    # SHOW fails: creation proceeds anyway (second execute).
    cursor.execute.side_effect = [RuntimeError("show failed"), None]
    vector._create_inverted_index(cursor)
    assert cursor.execute.call_count == 2
    cursor.reset_mock()
    # CREATE conflict ("already has index") is tolerated.
    cursor.execute.side_effect = [
        None,
        RuntimeError("already has index"),
        None,
    ]
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    vector._create_inverted_index(cursor)
    cursor.reset_mock()
    # Other create failures are swallowed rather than raised.
    cursor.execute.side_effect = [None, RuntimeError("other create failure")]
    cursor.fetchall.return_value = []
    vector._create_inverted_index(cursor)
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch):
    """add_texts batches writes; _insert_batch tolerates bad metadata and reports errors."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.batch_size = 2
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()
    vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id))
    docs = [
        Document(page_content="doc-1", metadata={"doc_id": "id-1"}),
        Document(page_content="doc-2", metadata={"doc_id": "id-2"}),
        Document(page_content="doc-3", metadata={"doc_id": "id-3"}),
    ]
    vectors = [[0.1], [0.2], [0.3]]
    # Empty input never schedules a write.
    vector.add_texts([], [])
    vector._execute_write.assert_not_called()
    # Three docs at batch_size=2 produce two scheduled batches.
    added_ids = vector.add_texts(docs, vectors)
    assert added_ids == ["id-1", "id-2", "id-3"]
    assert vector._execute_write.call_count == 2
    assert vector._execute_write.call_args_list[0].args == (
        vector._insert_batch,
        docs[:2],
        vectors[:2],
        ["id-1", "id-2"],
        0,
        2,
        2,
    )
    assert vector._execute_write.call_args_list[1].args == (
        vector._insert_batch,
        docs[2:],
        vectors[2:],
        ["id-3"],
        2,
        2,
        2,
    )
    # Degenerate batches (no docs / mismatched lengths) are no-ops.
    vector._insert_batch([], [], [], 0, 2, 1)
    vector._insert_batch(docs[:1], vectors, ["id-1"], 0, 2, 1)
    # A doc whose metadata is not JSON-serializable ({1} is a set) is handled.
    bad_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}})
    good_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"})
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx
    vector._insert_batch(
        [bad_doc, good_doc],
        [[0.1, 0.2], [0.3, 0.4]],
        ["id-bad", "id-good"],
        0,
        2,
        1,
    )
    @contextmanager
    def _ctx_error():
        # executemany failure must surface from _insert_batch.
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.executemany.side_effect = RuntimeError("insert failed")
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx_error
    with pytest.raises(RuntimeError, match="insert failed"):
        vector._insert_batch([good_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1)
    # Restore the real _safe_doc_id: empty/invalid ids fall back to uuid4.
    monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id")
    vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector)
    assert vector._safe_doc_id("") == "generated-id"
    assert vector._safe_doc_id("!!!") == "generated-id"
def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module):
    """Public deletes dispatch to their _impl writers; the impls run their SQL."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._execute_write = MagicMock()
    vector._table_exists = MagicMock(return_value=True)
    vector.delete_by_ids(["id-1", "id-2"])
    vector._execute_write.assert_called_once()
    assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl
    vector._execute_write.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._execute_write.assert_called_once()
    assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl
    # Run the impls directly against a mocked connection (no exception = pass).
    vector._safe_doc_id = MagicMock(side_effect=lambda x: x)
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx
    vector._delete_by_ids_impl(["id-1", "id-2"])
    vector._delete_by_metadata_field_impl("document_id", "doc-1")
def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module):
    """Scores derive from distance: 1 - d for cosine, 1 / (1 + d) for L2."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._config.vector_distance_function = "cosine_distance"
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)
    vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"})
    @contextmanager
    def _ctx():
        # Single row with distance 0.2 in the last column.
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)]
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx
    cosine_docs = vector.search_by_vector(
        [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"}
    )
    assert cosine_docs[0].metadata["score"] == pytest.approx(0.9)
    vector._config.vector_distance_function = "l2_distance"
    l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5)
    assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2)
def test_search_by_full_text_success_and_fallback(clickzetta_module):
    """Full-text search parses metadata rows; on SQL failure it falls back to LIKE."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=True)
    @contextmanager
    def _ctx_success():
        # One double-encoded metadata row and one invalid-JSON row.
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.fetchall.return_value = [
            ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'),
            ("seg-2", "content-2", "invalid-json"),
        ]
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx_success
    # Query contains a quote to exercise escaping in the generated SQL.
    docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1})
    assert len(docs) == 2
    assert docs[0].metadata["score"] == 1.0
    assert docs[1].metadata["doc_id"] == "seg-2"
    @contextmanager
    def _ctx_failure():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        cursor.execute.side_effect = RuntimeError("full text failed")
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx_failure
    # On failure the result comes straight from the LIKE fallback.
    vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})])
    fallback_docs = vector.search_by_full_text("query", top_k=1)
    assert fallback_docs == vector._search_by_like.return_value
def test_search_by_like_missing_table_and_delete_table(clickzetta_module):
    """LIKE search returns [] without a table; delete() runs against a mocked conn."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    vector._table_exists = MagicMock(return_value=False)
    assert vector._search_by_like("query", top_k=1) == []
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor = MagicMock()
        cursor.__enter__.return_value = cursor
        cursor.__exit__.return_value = None
        connection.cursor.return_value = cursor
        yield connection
    vector.get_connection_context = _ctx
    # Succeeds without raising against the mocked connection.
    vector.delete()
def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module):
    """Invalid returns close the connection; cleanup tolerates missing locks."""
    # Build the pool by hand so no background cleanup thread is started.
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._pools = {}
    pool._pool_locks = {}
    pool._max_pool_size = 1
    pool._connection_timeout = 10
    pool._lock = clickzetta_module.threading.Lock()
    pool._shutdown = False
    config = _config(clickzetta_module)
    key = pool._get_config_key(config)
    pool._pools[key] = [(MagicMock(), 1.0)]
    pool._pool_locks[key] = clickzetta_module.threading.Lock()
    # An invalid connection is closed instead of being returned to the pool.
    pool._is_connection_valid = MagicMock(return_value=False)
    conn = MagicMock()
    pool.return_connection(config, conn)
    conn.close.assert_called_once()
    # Cleanup must not blow up on a pool entry without a matching lock.
    pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)]
    pool._cleanup_expired_connections()
    pool.shutdown()
    assert pool._shutdown is True
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch):
    """The cleanup loop survives a raising cleanup call and still exits."""
    pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
    pool._shutdown = False
    def _cleanup_then_fail():
        # Flip the shutdown flag first so the loop terminates, then raise.
        pool._shutdown = True
        raise RuntimeError("cleanup failed")
    pool._cleanup_expired_connections = MagicMock(side_effect=_cleanup_then_fail)
    monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None)
    class _Thread:
        # Synchronous Thread replacement: start() runs the target inline.
        def __init__(self, target, daemon):
            self._target = target
            self.daemon = daemon
        def start(self):
            self._target()
    monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread)
    pool._start_cleanup_thread()
    pool._cleanup_expired_connections.assert_called_once()
def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module):
    """Non-dict/None metadata falls back to row ids; worker exits on queue errors."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    # JSON that parses to a non-dict is treated like missing metadata.
    parsed_non_dict = vector._parse_metadata("[1,2,3]", "row-1")
    assert parsed_non_dict["doc_id"] == "row-1"
    assert parsed_non_dict["document_id"] == "row-1"
    parsed_none = vector._parse_metadata(None, "row-2")
    assert parsed_none["doc_id"] == "row-2"
    assert parsed_none["document_id"] == "row-2"
    # Worker returns immediately when the queue is None.
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue = None
    clickzetta_module.ClickzettaVector._write_worker()
    class _BadQueue:
        # get() flips shutdown and raises, so the worker loop ends cleanly.
        def get(self, timeout):
            clickzetta_module.ClickzettaVector._shutdown = True
            raise RuntimeError("queue failed")
    clickzetta_module.ClickzettaVector._shutdown = False
    clickzetta_module.ClickzettaVector._write_queue = _BadQueue()
    clickzetta_module.ClickzettaVector._write_worker()
def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module):
    """Inverted-index creation tolerates duplicates; insert accepts non-dict metadata."""
    vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vector._config = _config(clickzetta_module)
    vector._table_name = "table_1"
    cursor = MagicMock()
    # An inverted index on page_content is already reported by the server.
    cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")]
    # Sequential execute() outcomes: success, then the "already has index"
    # error (which must be swallowed), then success again.
    cursor.execute.side_effect = [
        None,
        RuntimeError("already has index with the same type cannot create inverted index"),
        None,
    ]
    vector._create_inverted_index(cursor)
    vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value))
    # Minimal connection context yielding a mock connection/cursor pair.
    @contextmanager
    def _ctx():
        connection = MagicMock()
        cursor_obj = MagicMock()
        cursor_obj.__enter__.return_value = cursor_obj
        cursor_obj.__exit__.return_value = None
        connection.cursor.return_value = cursor_obj
        yield connection
    vector.get_connection_context = _ctx
    # Metadata that is not a dict must not crash the batch insert.
    vector._insert_batch(
        [SimpleNamespace(page_content="content", metadata="not-a-dict")],
        [[0.1, 0.2]],
        ["doc-1"],
        0,
        1,
        1,
    )
def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module):
    """Full-text search: empty result for a missing table, doc_id fallback otherwise."""
    vec = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
    vec._config = _config(clickzetta_module)
    vec._config.enable_inverted_index = True
    vec._table_name = "table_1"
    # A missing table short-circuits to an empty result list.
    vec._table_exists = MagicMock(return_value=False)
    assert vec.search_by_full_text("query") == []
    vec._table_exists = MagicMock(return_value=True)
    # One row with JSON-array metadata, one with no metadata at all.
    result_rows = [
        ("seg-1", "content-1", "[1,2,3]"),
        ("seg-2", "content-2", None),
    ]
    @contextmanager
    def _fake_connection():
        fake_cursor = MagicMock()
        fake_cursor.__enter__.return_value = fake_cursor
        fake_cursor.__exit__.return_value = None
        fake_cursor.fetchall.return_value = result_rows
        conn = MagicMock()
        conn.cursor.return_value = fake_cursor
        yield conn
    vec.get_connection_context = _fake_connection
    hits = vec.search_by_full_text("query")
    assert len(hits) == 2
    # Both rows fall back to the segment id as doc_id.
    assert hits[0].metadata["doc_id"] == "seg-1"
    assert hits[1].metadata["doc_id"] == "seg-2"

View File

@ -0,0 +1,364 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_couchbase_modules():
    """Build stub ``couchbase.*`` modules for insertion into ``sys.modules``.

    Returns a mapping of dotted module name -> fake module so that
    ``couchbase_vector`` imports resolve without the real SDK installed.
    """
    couchbase = types.ModuleType("couchbase")
    couchbase_auth = types.ModuleType("couchbase.auth")
    couchbase_cluster = types.ModuleType("couchbase.cluster")
    couchbase_management = types.ModuleType("couchbase.management")
    couchbase_management_search = types.ModuleType("couchbase.management.search")
    couchbase_options = types.ModuleType("couchbase.options")
    couchbase_vector = types.ModuleType("couchbase.vector_search")
    couchbase_search = types.ModuleType("couchbase.search")
    # --- lightweight stand-ins for the SDK's public value types ---
    class PasswordAuthenticator:
        def __init__(self, user, password):
            self.user = user
            self.password = password
    class ClusterOptions:
        def __init__(self, auth):
            self.auth = auth
    class SearchOptions:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
    class VectorQuery:
        def __init__(self, field, vector, top_k):
            self.field = field
            self.vector = vector
            self.top_k = top_k
    class VectorSearch:
        @staticmethod
        def from_vector_query(vector_query):
            return {"vector_query": vector_query}
    class QueryStringQuery:
        def __init__(self, query):
            self.query = query
    class SearchRequest:
        @staticmethod
        def create(payload):
            return {"payload": payload}
    class SearchIndex:
        def __init__(self, name, params, source_name):
            self.name = name
            self.params = params
            self.source_name = source_name
    # --- mocked cluster -> bucket -> scope object graph used by the tests ---
    class _QueryResult:
        def __init__(self, rows=None):
            self._rows = rows or []
        def execute(self):
            return self
        def __iter__(self):
            return iter(self._rows)
    class _SearchIter:
        def __init__(self, rows=None):
            self._rows = rows or []
        def rows(self):
            return self._rows
    class _Collection:
        def __init__(self):
            self.upsert = MagicMock(return_value=True)
    class _SearchIndexManager:
        def __init__(self):
            self.upsert_index = MagicMock()
    class _Scope:
        def __init__(self):
            self._collection = _Collection()
            self._search_index_manager = _SearchIndexManager()
            self.search = MagicMock(return_value=_SearchIter())
        def collection(self, _name):
            return self._collection
        def search_indexes(self):
            return self._search_index_manager
    class _CollectionManager:
        def __init__(self):
            self.create_collection = MagicMock()
            self.drop_collection = MagicMock()
            self.get_all_scopes = MagicMock(return_value=[])
    class _Bucket:
        def __init__(self):
            self._scope = _Scope()
            self._collections = _CollectionManager()
        def scope(self, _scope_name):
            return self._scope
        def collections(self):
            return self._collections
    class Cluster:
        def __init__(self, connection_string, options):
            self.connection_string = connection_string
            self.options = options
            self._bucket = _Bucket()
            self.wait_until_ready = MagicMock()
            self.query = MagicMock(return_value=_QueryResult())
        def bucket(self, _name):
            return self._bucket
    # --- wire classes onto the fake module tree ---
    couchbase_auth.PasswordAuthenticator = PasswordAuthenticator
    couchbase_cluster.Cluster = Cluster
    couchbase_management_search.SearchIndex = SearchIndex
    couchbase_options.ClusterOptions = ClusterOptions
    couchbase_options.SearchOptions = SearchOptions
    couchbase_vector.VectorQuery = VectorQuery
    couchbase_vector.VectorSearch = VectorSearch
    couchbase_search.QueryStringQuery = QueryStringQuery
    couchbase_search.SearchRequest = SearchRequest
    couchbase.search = couchbase_search
    couchbase.management = couchbase_management
    return {
        "couchbase": couchbase,
        "couchbase.auth": couchbase_auth,
        "couchbase.cluster": couchbase_cluster,
        "couchbase.management": couchbase_management,
        "couchbase.management.search": couchbase_management_search,
        "couchbase.options": couchbase_options,
        "couchbase.vector_search": couchbase_vector,
        "couchbase.search": couchbase_search,
    }
@pytest.fixture
def couchbase_module(monkeypatch):
    """Reload couchbase_vector with the stubbed couchbase SDK installed."""
    for name, module in _build_fake_couchbase_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.couchbase.couchbase_vector as module
    # Reload so the module binds to the stub SDK even if imported earlier.
    return importlib.reload(module)
def _config(module):
    """Return a CouchbaseConfig populated with minimal valid test values."""
    settings = {
        "connection_string": "couchbase://localhost",
        "user": "user",
        "password": "pass",
        "bucket_name": "bucket",
        "scope_name": "scope",
    }
    return module.CouchbaseConfig(**settings)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("connection_string", "", "CONNECTION_STRING is required"),
        ("user", "", "COUCHBASE_USER is required"),
        ("password", "", "COUCHBASE_PASSWORD is required"),
        # NOTE(review): the bucket_name row expects the PASSWORD message. This
        # mirrors the implementation's validator, which appears to reuse the
        # COUCHBASE_PASSWORD message for the bucket_name check (likely a
        # copy-paste) — verify against CouchbaseConfig before changing it.
        ("bucket_name", "", "COUCHBASE_PASSWORD is required"),
        ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"),
    ],
)
def test_couchbase_config_validation(couchbase_module, field, value, message):
    """Each required CouchbaseConfig field raises ValidationError when blank."""
    values = _config(couchbase_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        couchbase_module.CouchbaseConfig.model_validate(values)
def test_init_sets_cluster_handles(couchbase_module):
    """Constructor stores bucket/scope names and waits for cluster readiness."""
    cb_config = _config(couchbase_module)
    cb_vector = couchbase_module.CouchbaseVector("collection_1", cb_config)
    assert cb_vector._bucket_name == "bucket"
    assert cb_vector._scope_name == "scope"
    cb_vector._cluster.wait_until_ready.assert_called_once()
def test_create_and_create_collection_branches(couchbase_module, monkeypatch):
    """Cover create() delegation plus all three _create_collection branches."""
    vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector)
    vector._collection_name = "collection_1"
    vector._client_config = _config(couchbase_module)
    vector._scope_name = "scope"
    vector._bucket_name = "bucket"
    vector._bucket = MagicMock()
    vector._scope = MagicMock()
    vector._collection_exists = MagicMock(return_value=False)
    vector.add_texts = MagicMock()
    monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c")
    vector._create_collection = MagicMock()
    docs = [Document(page_content="text", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1, 0.2]])
    # create() strips dashes from the uuid and forwards the embedding length.
    vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock())
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    # Branch 1: redis cache hit -> nothing is created.
    monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(vector_length=2, uuid="uuid-1")
    vector._bucket.collections().create_collection.assert_not_called()
    # Branch 2: collection already exists server-side -> skip creation.
    monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._collection_exists = MagicMock(return_value=True)
    vector._create_collection(vector_length=2, uuid="uuid-2")
    vector._bucket.collections().create_collection.assert_not_called()
    # Branch 3: collection and its search index are created.
    vector._collection_exists = MagicMock(return_value=False)
    vector._create_collection(vector_length=3, uuid="uuid-3")
    vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1")
    vector._scope.search_indexes().upsert_index.assert_called_once()
    search_index = vector._scope.search_indexes().upsert_index.call_args.args[0]
    assert search_index.name == "collection_1_search"
    # Vector field dims in the index mapping must match vector_length.
    assert (
        search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"]
        == 3
    )
    couchbase_module.redis_client.set.assert_called_once()
def test_collection_exists_get_type_and_add_texts(couchbase_module):
    """_collection_exists scans scopes; add_texts upserts one doc per id."""
    vec = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    def _install_scope(collection_name):
        # Replace the server-side scope listing with a single named collection.
        scope = SimpleNamespace(name="scope", collections=[SimpleNamespace(name=collection_name)])
        vec._bucket.collections().get_all_scopes.return_value = [scope]
    _install_scope("collection_1")
    assert vec._collection_exists("collection_1") is True
    _install_scope("other")
    assert vec._collection_exists("collection_1") is False
    vec._get_uuids = MagicMock(return_value=["id-1", "id-2"])
    documents = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    returned_ids = vec.add_texts(documents, [[0.1], [0.2]])
    assert returned_ids == ["id-1", "id-2"]
    assert vec._scope.collection("collection_1").upsert.call_count == 2
    assert vec.get_type() == couchbase_module.VectorType.COUCHBASE
def test_query_delete_helpers(couchbase_module):
    """text_exists reads the count query; delete helpers swallow query errors."""
    vec = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    def _stub_query_rows(rows):
        vec._cluster.query.return_value = SimpleNamespace(execute=lambda: iter(rows))
    _stub_query_rows([{"count": 2}])
    assert vec.text_exists("id-1") is True
    _stub_query_rows([])
    assert vec.text_exists("id-2") is False
    delete_result = MagicMock()
    delete_result.execute.return_value = None
    vec._cluster.query.return_value = delete_result
    vec.delete_by_ids(["id-1", "id-2"])
    vec.delete_by_document_id("id-1")
    vec.delete_by_metadata_field("document_id", "doc-1")
    assert vec._cluster.query.call_count >= 3
    # A failing N1QL delete must not propagate out of delete_by_ids.
    vec._cluster.query.side_effect = RuntimeError("delete failed")
    vec.delete_by_ids(["id-3"])
def test_search_methods_and_format_metadata(couchbase_module):
    """Vector/full-text search filtering, error wrapping and metadata key rewrite."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    row_1 = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9)
    row_2 = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3)
    vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_1, row_2])
    # Only the row above the 0.5 threshold survives.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].page_content == "doc-a"
    assert docs[0].metadata["document_id"] == "d-1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # SDK failures are re-raised as ValueError("Search failed ...").
    vector._scope.search.side_effect = RuntimeError("search error")
    with pytest.raises(ValueError, match="Search failed"):
        vector.search_by_vector([0.1], top_k=1)
    vector._scope.search.side_effect = None
    row_3 = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7)
    vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_3])
    docs = vector.search_by_full_text("hello", top_k=1)
    assert len(docs) == 1
    assert docs[0].metadata["doc_id"] == "x"
    vector._scope.search.side_effect = RuntimeError("full text failed")
    with pytest.raises(ValueError, match="Search failed"):
        vector.search_by_full_text("hello", top_k=1)
    # "metadata." prefixes are stripped; other keys pass through unchanged.
    assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2}
def test_delete_collection_and_factory(couchbase_module, monkeypatch):
    """delete() drops the matching collection; factory reuses or generates names."""
    vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
    scopes = [
        SimpleNamespace(collections=[SimpleNamespace(name="other")]),
        SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]),
    ]
    vector._bucket.collections().get_all_scopes.return_value = scopes
    vector.delete()
    # NOTE(review): drop_collection is expected with the "_default" scope —
    # confirm this matches the implementation's behaviour for named scopes.
    vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1")
    factory = couchbase_module.CouchbaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(
        couchbase_module,
        "current_app",
        SimpleNamespace(
            config={
                "COUCHBASE_CONNECTION_STRING": "couchbase://localhost",
                "COUCHBASE_USER": "user",
                "COUCHBASE_PASSWORD": "pass",
                "COUCHBASE_BUCKET_NAME": "bucket",
                "COUCHBASE_SCOPE_NAME": "scope",
            }
        ),
    )
    with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # An existing index struct wins; otherwise the name is generated from the id.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,121 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
def _build_fake_elasticsearch_modules():
elasticsearch = types.ModuleType("elasticsearch")
class ConnectionError(Exception):
pass
class Elasticsearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.ping = MagicMock(return_value=True)
self.info = MagicMock(return_value={"version": {"number": "8.12.0"}})
self.indices = SimpleNamespace(
refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock()
)
elasticsearch.Elasticsearch = Elasticsearch
elasticsearch.ConnectionError = ConnectionError
return {"elasticsearch": elasticsearch}
@pytest.fixture
def elasticsearch_ja_module(monkeypatch):
    """Import the JA vector module against the stubbed elasticsearch SDK.

    The base module is reloaded first so its references also point at the
    stub before the JA subclass module is reloaded.
    """
    for name, module in _build_fake_elasticsearch_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module
    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module
    importlib.reload(base_module)
    return importlib.reload(ja_module)
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
    """A redis cache hit skips index creation and the cache write entirely."""
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock())
    ja_vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector)
    ja_vector._collection_name = "test"
    ja_vector._client = MagicMock()
    ja_vector.create_collection([[0.1, 0.2]], [{}])
    ja_vector._client.indices.create.assert_not_called()
    elasticsearch_ja_module.redis_client.set.assert_not_called()
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch):
    """Index is created when absent, skipped when present; cache flag set in both."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock())
    vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector)
    vector._collection_name = "test"
    vector._client = MagicMock()
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2, 0.3]], [{}])
    vector._client.indices.create.assert_called_once()
    kwargs = vector._client.indices.create.call_args.kwargs
    assert kwargs["index"] == "test"
    # Mapping dims must follow the embedding length (3 here).
    assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3
    elasticsearch_ja_module.redis_client.set.assert_called_once()
    vector._client.indices.create.reset_mock()
    elasticsearch_ja_module.redis_client.set.reset_mock()
    # Pre-existing index: creation skipped but the cache flag is still written.
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    elasticsearch_ja_module.redis_client.set.assert_called_once()
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch):
    """Factory reuses the index name from index_struct_dict or generates one."""
    factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(
        elasticsearch_ja_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_HOST": "localhost",
                "ELASTICSEARCH_PORT": 9200,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
            }
        ),
    )
    with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION"
    # init_vector persists a generated index struct back onto the dataset.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,405 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_elasticsearch_modules():
elasticsearch = types.ModuleType("elasticsearch")
class ConnectionError(Exception):
pass
class Elasticsearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.ping = MagicMock(return_value=True)
self.info = MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}})
self.index = MagicMock()
self.exists = MagicMock(return_value=False)
self.delete = MagicMock()
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.indices = SimpleNamespace(
refresh=MagicMock(),
delete=MagicMock(),
exists=MagicMock(return_value=False),
create=MagicMock(),
)
elasticsearch.Elasticsearch = Elasticsearch
elasticsearch.ConnectionError = ConnectionError
return {"elasticsearch": elasticsearch}
@pytest.fixture
def elasticsearch_module(monkeypatch):
    """Reload the ES vector module with the stubbed elasticsearch SDK installed."""
    for name, module in _build_fake_elasticsearch_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module
    return importlib.reload(module)
def _regular_config(module, **overrides):
    """Self-managed ES config payload; keyword overrides replace defaults."""
    defaults = {
        "host": "localhost",
        "port": 9200,
        "username": "elastic",
        "password": "secret",
        "verify_certs": False,
        "request_timeout": 10,
        "retry_on_timeout": True,
        "max_retries": 3,
    }
    return module.ElasticSearchConfig.model_validate({**defaults, **overrides})
def _cloud_config(module, **overrides):
    """Elastic Cloud config payload; keyword overrides replace defaults."""
    defaults = {
        "use_cloud": True,
        "cloud_url": "https://cloud.example:9243",
        "api_key": "api-key",
        "verify_certs": True,
        "ca_certs": "/tmp/ca.pem",
        "request_timeout": 10,
        "retry_on_timeout": True,
        "max_retries": 3,
    }
    return module.ElasticSearchConfig.model_validate({**defaults, **overrides})
@pytest.mark.parametrize(
    ("values", "message"),
    [
        # Cloud mode requires cloud_url and api_key.
        ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"),
        ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"),
        # Self-managed mode requires host/port/username/password.
        ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"),
        ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"),
        ({"host": "h", "port": 9200, "username": None, "password": "p"}, "USERNAME is required"),
        ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"),
    ],
)
def test_elasticsearch_config_validation(elasticsearch_module, values, message):
    """Each missing required field raises a ValidationError naming that field."""
    with pytest.raises(ValidationError, match=message):
        elasticsearch_module.ElasticSearchConfig.model_validate(values)
def test_init_client_cloud_configuration(elasticsearch_module):
    """Cloud config builds the client from cloud_url/api_key/certificate settings."""
    es_vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    fake_client = MagicMock()
    fake_client.ping.return_value = True
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=fake_client) as es_cls:
        created = es_vector._init_client(_cloud_config(elasticsearch_module))
    assert created is fake_client
    ctor_kwargs = es_cls.call_args.kwargs
    assert ctor_kwargs["hosts"] == ["https://cloud.example:9243"]
    assert ctor_kwargs["api_key"] == "api-key"
    assert ctor_kwargs["verify_certs"] is True
    assert ctor_kwargs["ca_certs"] == "/tmp/ca.pem"
def test_init_client_regular_https_and_http_fallback(elasticsearch_module):
    """Self-managed config builds https URLs (with certs) or plain http ones."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    client = MagicMock()
    client.ping.return_value = True
    # A host already carrying an https scheme keeps it and forwards cert options.
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls:
        vector._init_client(
            _regular_config(
                elasticsearch_module,
                host="https://es.example",
                port=9443,
                verify_certs=True,
                ca_certs="/tmp/ca.pem",
            )
        )
    kwargs = es_cls.call_args.kwargs
    assert kwargs["hosts"] == ["https://es.example:9443"]
    assert kwargs["verify_certs"] is True
    assert kwargs["ca_certs"] == "/tmp/ca.pem"
    # A bare host falls back to http and omits TLS options entirely.
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls:
        vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200))
    kwargs = es_cls.call_args.kwargs
    assert kwargs["hosts"] == ["http://es.internal:9200"]
    assert "verify_certs" not in kwargs
def test_init_client_connection_failures(elasticsearch_module):
    """All connection failure modes surface as ConnectionError with clear text."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    client = MagicMock()
    client.ping.return_value = False
    # ping() returning False -> generic "Failed to connect" error.
    with patch.object(elasticsearch_module, "Elasticsearch", return_value=client):
        with pytest.raises(ConnectionError, match="Failed to connect"):
            vector._init_client(_regular_config(elasticsearch_module))
    # SDK-level connection error is wrapped with a dedicated message.
    with patch.object(
        elasticsearch_module,
        "Elasticsearch",
        side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"),
    ):
        with pytest.raises(ConnectionError, match="Vector database connection error"):
            vector._init_client(_regular_config(elasticsearch_module))
    # Any other exception is reported as an initialization failure.
    with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")):
        with pytest.raises(ConnectionError, match="initialization failed"):
            vector._init_client(_regular_config(elasticsearch_module))
def test_init_get_version_and_check_version(elasticsearch_module):
    """__init__ wiring plus _get_version snapshot-stripping and the version gate."""
    with (
        patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client,
        patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version,
        patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version,
    ):
        vector = elasticsearch_module.ElasticSearchVector(
            "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"]
        )
    init_client.assert_called_once()
    get_version.assert_called_once()
    check_version.assert_called_once()
    assert vector._attributes == ["doc_id"]
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._client = MagicMock()
    vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}}
    # The "-SNAPSHOT" suffix is stripped from the reported version.
    assert vector._get_version() == "8.13.2"
    # Versions below 8.0.0 are rejected; exactly 8.0.0 passes.
    vector._version = "7.17.0"
    with pytest.raises(ValueError, match="greater than 8.0.0"):
        vector._check_version()
    vector._version = "8.0.0"
    vector._check_version()
def test_crud_methods_and_get_type(elasticsearch_module):
    """add/exists/delete helpers drive the mocked ES client as expected."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock())
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    assert vector._client.index.call_count == 2
    # One refresh after the whole batch, not per document.
    vector._client.indices.refresh.assert_called_once_with(index="collection_1")
    vector._client.exists.return_value = True
    assert vector.text_exists("id-1") is True
    # An empty id list is a no-op.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector.delete_by_ids(["id-1", "id-2"])
    assert vector._client.delete.call_count == 2
    # delete_by_metadata_field searches first, then deletes only matched ids.
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}}
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "d1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
    vector.delete_by_ids.reset_mock()
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_metadata_field("doc_id", "d2")
    vector.delete_by_ids.assert_not_called()
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection_1")
    assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH
def test_search_by_vector_and_full_text(elasticsearch_module):
    """knn search honours score_threshold/filter; full-text builds a bool query."""
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.8,
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "doc-a",
                        elasticsearch_module.Field.VECTOR: [0.1],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"},
                    },
                },
                {
                    "_score": 0.2,
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "doc-b",
                        elasticsearch_module.Field.VECTOR: [0.2],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"},
                    },
                },
            ]
        }
    }
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    # Only the 0.8-score hit clears the 0.5 threshold.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    knn = vector._client.search.call_args.kwargs["knn"]
    assert knn["k"] == 2
    # Expected num_candidates is 3 for top_k=2 — presumably top_k + 1; confirm
    # against the implementation if this assertion ever changes.
    assert knn["num_candidates"] == 3
    # The document_ids_filter must appear as a knn filter clause.
    assert "filter" in knn
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        elasticsearch_module.Field.CONTENT_KEY: "text-hit",
                        elasticsearch_module.Field.VECTOR: [0.3],
                        elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"},
                    }
                }
            ]
        }
    }
    docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"])
    assert len(docs) == 1
    assert docs[0].page_content == "text-hit"
    # The document id filter forces a compound bool query.
    query = vector._client.search.call_args.kwargs["query"]
    assert "bool" in query
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
    """create() delegates; create_collection covers cache-hit/create/exists paths."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock())
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock())
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="a", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock())
    # Cache hit: nothing is created.
    monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    # Cache miss + missing index: index is created with dims from the embedding.
    monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_called_once()
    mappings = vector._client.indices.create.call_args.kwargs["mappings"]
    assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2
    elasticsearch_module.redis_client.set.assert_called_once()
    vector._client.indices.create.reset_mock()
    elasticsearch_module.redis_client.set.reset_mock()
    # Cache miss + existing index: skip creation but still set the cache flag.
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    elasticsearch_module.redis_client.set.assert_called_once()
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch):
    """Factory builds self-managed, cloud, and cloud-fallback configurations."""
    factory = elasticsearch_module.ElasticSearchVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Branch 1: self-managed config, index name taken from index_struct_dict.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": False,
                "ELASTICSEARCH_HOST": "es-host",
                "ELASTICSEARCH_PORT": 9200,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
                "ELASTICSEARCH_VERIFY_CERTS": False,
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    cfg = vector_cls.call_args.kwargs["config"]
    assert cfg.use_cloud is False
    assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION"
    # Branch 2: cloud config with a generated index name.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": True,
                "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic",
                "ELASTICSEARCH_API_KEY": "api-key",
                "ELASTICSEARCH_VERIFY_CERTS": True,
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_2 == "vector"
    cfg = vector_cls.call_args.kwargs["config"]
    assert cfg.use_cloud is True
    assert cfg.cloud_url == "https://cloud.elastic"
    assert dataset_without_index.index_struct is not None
    # Branch 3: cloud requested but no cloud_url -> falls back to host config.
    monkeypatch.setattr(
        elasticsearch_module,
        "current_app",
        SimpleNamespace(
            config={
                "ELASTICSEARCH_USE_CLOUD": True,
                "ELASTICSEARCH_CLOUD_URL": None,
                "ELASTICSEARCH_HOST": "fallback-host",
                "ELASTICSEARCH_PORT": 9201,
                "ELASTICSEARCH_USERNAME": "elastic",
                "ELASTICSEARCH_PASSWORD": "secret",
            }
        ),
    )
    with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls:
        factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    cfg = vector_cls.call_args.kwargs["config"]
    assert cfg.use_cloud is False
    assert cfg.host == "fallback-host"

View File

@ -0,0 +1,371 @@
import importlib
import json
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_hologres_modules():
holo_module = types.ModuleType("holo_search_sdk")
holo_types_module = types.ModuleType("holo_search_sdk.types")
holo_types_module.BaseQuantizationType = str
holo_types_module.DistanceType = str
holo_types_module.TokenizerType = str
def _connect(**kwargs):
client = MagicMock()
client.kwargs = kwargs
client.connect = MagicMock()
client.check_table_exist = MagicMock(return_value=False)
client.open_table = MagicMock(return_value=MagicMock())
client.execute = MagicMock(return_value=[])
client.drop_table = MagicMock()
return client
holo_module.connect = MagicMock(side_effect=_connect)
return {
"holo_search_sdk": holo_module,
"holo_search_sdk.types": holo_types_module,
}
@pytest.fixture
def hologres_module(monkeypatch):
    """Import the Hologres vector module with the fake SDK modules installed."""
    for name, module in _build_fake_hologres_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.hologres.hologres_vector as module
    # Reload so the module-level SDK imports resolve against the fakes.
    return importlib.reload(module)
def _valid_config(module):
    """Build a fully-populated HologresVectorConfig that passes validation."""
    return module.HologresVectorConfig(
        host="localhost",
        port=80,
        database="dify",
        access_key_id="ak",
        access_key_secret="sk",
        schema_name="public",
        tokenizer="jieba",
        distance_method="Cosine",
        base_quantization_type="rabitq",
        max_degree=64,
        ef_construction=400,
    )
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config HOLOGRES_HOST is required"),
        ("database", "", "config HOLOGRES_DATABASE is required"),
        ("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"),
        ("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"),
    ],
)
def test_hologres_config_validation(hologres_module, field, value, message):
    """Blanking each required config field raises a ValidationError with its message."""
    values = _valid_config(hologres_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        hologres_module.HologresVectorConfig.model_validate(values)
def test_init_client_and_get_type(hologres_module):
    """Constructing the vector connects via the SDK and normalizes the table name."""
    vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module))
    hologres_module.holo.connect.assert_called_once_with(
        host="localhost",
        port=80,
        database="dify",
        access_key_id="ak",
        access_key_secret="sk",
        schema="public",
    )
    vector._client.connect.assert_called_once()
    # The collection name is lower-cased and prefixed for the physical table.
    assert vector.table_name == "embedding_collection_one"
    assert vector.get_type() == hologres_module.VectorType.HOLOGRES
def test_create_delegates_collection_creation_and_upsert(hologres_module):
    """create() builds the collection sized to the embedding dim, then upserts."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    result = vector.create(docs, [[0.1, 0.2]])
    assert result is None
    # The dimension argument (2) comes from the length of the first embedding.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_returns_empty_for_empty_documents(hologres_module):
    """add_texts with no documents is a no-op that yields an empty id list."""
    holo_vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    result = holo_vector.add_texts([], [])
    assert result == []
    # No table should even be opened for an empty batch.
    holo_vector._client.open_table.assert_not_called()
def test_add_texts_batches_and_serializes_metadata(hologres_module):
    """101 documents are upserted in two batches with JSON-encoded metadata."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    table = vector._client.open_table.return_value
    documents = [
        Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"})
        for i in range(100)
    ]
    # A document without metadata exercises the fallback id/meta handling.
    documents.append(SimpleNamespace(page_content="doc-100", metadata=None))
    embeddings = [[float(i)] for i in range(len(documents))]
    ids = vector.add_texts(documents, embeddings)
    assert ids[:2] == ["id-0", "id-1"]
    assert ids[-1] == ""
    assert len(ids) == 101
    # 101 docs split into two upsert rounds (batching — presumably 100 per batch).
    assert vector._client.open_table.call_count == 2
    assert table.upsert_multi.call_count == 2
    first_call = table.upsert_multi.call_args_list[0].kwargs
    second_call = table.upsert_multi.call_args_list[1].kwargs
    assert first_call["index_column"] == "id"
    assert first_call["column_names"] == ["id", "text", "meta", "embedding"]
    assert first_call["update_columns"] == ["text", "meta", "embedding"]
    assert len(first_call["values"]) == 100
    assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"}
    # The metadata-less document falls back to an empty id and "{}" meta.
    assert second_call["values"][0][0] == ""
    assert second_call["values"][0][2] == "{}"
def test_text_exists_handles_missing_and_present_tables(hologres_module):
    """text_exists is False without a table and True once a matching row is found."""
    holo_vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # First call sees no table, second call sees an existing one.
    holo_vector._client.check_table_exist.side_effect = [False, True]
    holo_vector._client.execute.return_value = [(1,)]
    assert holo_vector.text_exists("seg-1") is False
    assert holo_vector.text_exists("seg-1") is True
    # Only the second lookup (table present) should have queried the database.
    holo_vector._client.execute.assert_called_once()
def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module):
    """Row tuples are flattened to an id list; an empty result set yields ``None``."""
    holo_vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # First query finds two rows, second finds nothing.
    holo_vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []]
    assert holo_vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
    assert holo_vector.get_ids_by_metadata_field("document_id", "doc-1") is None
def test_delete_by_ids_branches(hologres_module):
    """delete_by_ids: empty input is a no-op; missing table skips SQL; else deletes."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector.delete_by_ids([])
    # Empty input returns before even checking table existence.
    vector._client.check_table_exist.assert_not_called()
    vector._client.check_table_exist.return_value = False
    vector.delete_by_ids(["id-1"])
    vector._client.execute.assert_not_called()
    vector._client.check_table_exist.return_value = True
    vector.delete_by_ids(["id-1", "id-2"])
    vector._client.execute.assert_called_once()
def test_delete_by_metadata_field_branches(hologres_module):
    """delete_by_metadata_field only issues SQL when the table exists."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = False
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.execute.assert_not_called()
    vector._client.check_table_exist.return_value = True
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.execute.assert_called_once()
def test_search_by_vector_returns_empty_when_table_missing(hologres_module):
    """Vector search short-circuits to [] when the backing table does not exist."""
    holo_vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    holo_vector._client.check_table_exist.return_value = False
    result = holo_vector.search_by_vector([0.1, 0.2])
    assert result == []
def test_search_by_vector_applies_filter_and_processes_results(hologres_module):
    """Vector search applies the document-id filter and the score threshold."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = True
    table = vector._client.open_table.return_value
    # The query object is a fluent builder: every chained call returns itself.
    query = MagicMock()
    table.search_vector.return_value = query
    query.select.return_value = query
    query.limit.return_value = query
    query.where.return_value = query
    # Rows are (distance, id, text, meta); meta may be a JSON string or a dict.
    query.fetchall.return_value = [
        (0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'),
        (0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}),
    ]
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["doc-1"],
    )
    # Distance 0.2 maps to score 0.8 and clears the 0.5 threshold; 0.9 is dropped.
    assert len(docs) == 1
    assert docs[0].page_content == "doc-1"
    assert docs[0].metadata["doc_id"] == "seg-1"
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    table.search_vector.assert_called_once()
    # The document-id filter must translate into a where() clause.
    query.where.assert_called_once()
def test_search_by_full_text_returns_empty_when_table_missing(hologres_module):
    """Full-text search short-circuits to [] when the backing table is absent."""
    holo_vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    holo_vector._client.check_table_exist.return_value = False
    result = holo_vector.search_by_full_text("query")
    assert result == []
def test_search_by_full_text_applies_filter_and_processes_results(hologres_module):
    """Full-text search keeps every hit and attaches its score verbatim."""
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._client.check_table_exist.return_value = True
    table = vector._client.open_table.return_value
    # Fluent query builder: chained calls return the same mock.
    search_query = MagicMock()
    table.search_text.return_value = search_query
    search_query.limit.return_value = search_query
    search_query.where.return_value = search_query
    # Rows are (id, text, meta, embedding, score); meta may be str or dict.
    search_query.fetchall.return_value = [
        ("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95),
        ("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7),
    ]
    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"])
    assert len(docs) == 2
    assert docs[0].metadata["doc_id"] == "seg-1"
    assert docs[0].metadata["score"] == pytest.approx(0.95)
    assert docs[1].metadata["score"] == pytest.approx(0.7)
    table.search_text.assert_called_once()
    search_query.where.assert_called_once()
def test_delete_handles_existing_and_missing_tables(hologres_module):
    """delete() drops the table only when it actually exists."""
    holo_vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # First call sees a missing table, second call sees an existing one.
    holo_vector._client.check_table_exist.side_effect = [False, True]
    holo_vector.delete()
    holo_vector._client.drop_table.assert_not_called()
    holo_vector.delete()
    holo_vector._client.drop_table.assert_called_once_with(holo_vector.table_name)
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch):
    """A redis cache hit means the collection was built before; nothing runs."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = False
    monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
    # get() returning a truthy value simulates the "already created" marker.
    monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    vector._create_collection(3)
    vector._client.check_table_exist.assert_not_called()
    hologres_module.redis_client.set.assert_not_called()
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch):
    """On a cache miss the table is created, polled until ready, and indexed."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = False
    monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
    # Stub sleep so the readiness polling loop runs instantly.
    monkeypatch.setattr(hologres_module.time, "sleep", MagicMock())
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # Table is absent at first, then shows up on the readiness poll.
    vector._client.check_table_exist.side_effect = [False, False, True]
    table = vector._client.open_table.return_value
    vector._create_collection(3)
    vector._client.execute.assert_called_once()
    # Vector index parameters mirror the config values.
    table.set_vector_index.assert_called_once_with(
        column="embedding",
        distance_method="Cosine",
        base_quantization_type="rabitq",
        max_degree=64,
        ef_construction=400,
        use_reorder=True,
    )
    table.create_text_index.assert_called_once_with(
        index_name="ft_idx_collection_one",
        column="text",
        tokenizer="jieba",
    )
    # The "already created" marker is written back to redis on success.
    hologres_module.redis_client.set.assert_called_once()
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch):
    """If the table never appears within the poll budget, a RuntimeError surfaces."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = False
    monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(hologres_module.time, "sleep", MagicMock())
    vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module))
    # Initial existence check plus every readiness poll reports "missing".
    vector._client.check_table_exist.side_effect = [False] + [False] * 15
    with pytest.raises(RuntimeError, match="was not ready after 30s"):
        vector._create_collection(3)
    # Failure must not mark the collection as created in redis.
    hologres_module.redis_client.set.assert_not_called()
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch):
    """The factory reuses a stored collection name or generates and persists one."""
    factory = hologres_module.HologresVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection")
    # Populate every config field the factory reads when building the config.
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80)
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq")
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64)
    monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_EF_CONSTRUCTION", 400)
    with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection"
    generated_config = vector_cls.call_args_list[1].kwargs["config"]
    assert generated_config.host == "127.0.0.1"
    assert generated_config.database == "dify"
    assert generated_config.access_key_id == "ak"
    # The generated name is written back to the dataset's index struct as JSON.
    assert json.loads(dataset_without_index.index_struct) == {
        "type": hologres_module.VectorType.HOLOGRES,
        "vector_store": {"class_prefix": "generated_collection"},
    }

View File

@ -0,0 +1,243 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_elasticsearch_modules():
elasticsearch = types.ModuleType("elasticsearch")
class Elasticsearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.index = MagicMock()
self.exists = MagicMock(return_value=False)
self.delete = MagicMock()
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.indices = SimpleNamespace(
refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock()
)
elasticsearch.Elasticsearch = Elasticsearch
return {"elasticsearch": elasticsearch}
@pytest.fixture
def huawei_module(monkeypatch):
    """Import the Huawei Cloud vector module with the fake elasticsearch installed."""
    for name, module in _build_fake_elasticsearch_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module
    # Reload so the module binds to the stubbed Elasticsearch class.
    return importlib.reload(module)
def _config(module):
    """Return a minimal valid HuaweiCloudVectorConfig for the tests."""
    return module.HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="user", password="pass")
def test_create_ssl_context(huawei_module):
    """The helper disables hostname checking and certificate verification."""
    ssl_ctx = huawei_module.create_ssl_context()
    assert ssl_ctx.verify_mode == huawei_module.ssl.CERT_NONE
    assert ssl_ctx.check_hostname is False
def test_huawei_config_validation_and_params(huawei_module):
    """Config validation rejects empty hosts; params include auth only when set."""
    with pytest.raises(ValidationError, match="HOSTS is required"):
        huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""})
    config = _config(huawei_module)
    params = config.to_elasticsearch_params()
    assert params["hosts"] == ["http://localhost:9200"]
    assert params["basic_auth"] == ("user", "pass")
    # Without credentials the basic_auth entry is omitted entirely.
    config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None)
    params = config.to_elasticsearch_params()
    assert "basic_auth" not in params
def test_init_get_type_and_add_texts(huawei_module):
    """The index name is lower-cased; add_texts indexes each doc then refreshes."""
    vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module))
    assert vector._collection_name == "collection"
    assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    # One index call per document, then a single refresh of the whole index.
    assert vector._client.index.call_count == 2
    vector._client.indices.refresh.assert_called_once_with(index="collection")
def test_crud_methods(huawei_module):
    """Exercise text_exists, delete_by_ids, delete_by_metadata_field and delete."""
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    vector._client.exists.return_value = True
    assert vector.text_exists("id-1") is True
    # Empty id list must not issue any delete call.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector.delete_by_ids(["id-1"])
    vector._client.delete.assert_called_once_with(index="collection", id="id-1")
    # delete_by_metadata_field resolves matching ids via search, then delegates.
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}}
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "x")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
    vector.delete_by_ids.reset_mock()
    # No matching hits -> no delegation to delete_by_ids.
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_metadata_field("doc_id", "x")
    vector.delete_by_ids.assert_not_called()
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection")
def test_search_by_vector_and_full_text(huawei_module):
    """Vector search honors the score threshold; full-text maps hits to Documents."""
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.9,
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "doc-a",
                        huawei_module.Field.VECTOR: [0.1],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "1"},
                    },
                },
                {
                    "_score": 0.1,
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "doc-b",
                        huawei_module.Field.VECTOR: [0.2],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "2"},
                    },
                },
            ]
        }
    }
    # Only the 0.9 hit clears the 0.5 score threshold.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # The knn query body carries top_k through to the "topk" parameter.
    query_body = vector._client.search.call_args.kwargs["body"]
    assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "text-hit",
                        huawei_module.Field.VECTOR: [0.3],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "3"},
                    }
                }
            ]
        }
    }
    docs = vector.search_by_full_text("hello", top_k=3)
    assert len(docs) == 1
    assert docs[0].page_content == "text-hit"
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch):
    """Hits whose Document ends up with metadata=None are dropped from results."""

    class FakeDocument:
        # Mimics Document's constructor but forces metadata to None, so the
        # search path's metadata-is-None guard is exercised.
        def __init__(self, page_content, vector, metadata):
            self.page_content = page_content
            self.vector = vector
            self.metadata = None

    monkeypatch.setattr(huawei_module, "Document", FakeDocument)
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.9,
                    "_source": {
                        huawei_module.Field.CONTENT_KEY: "doc-a",
                        huawei_module.Field.VECTOR: [0.1],
                        huawei_module.Field.METADATA_KEY: {"doc_id": "1"},
                    },
                }
            ]
        }
    }
    docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5)
    assert docs == []
def test_create_and_create_collection_paths(huawei_module, monkeypatch):
    """create() delegates; create_collection honors the redis cache and mappings."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock())
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="a", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module))
    # Cache hit: the index is assumed to exist already, so nothing is created.
    monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_not_called()
    # Cache miss + missing index: the index is created with a 2-dim vector mapping.
    monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], [{}])
    vector._client.indices.create.assert_called_once()
    kwargs = vector._client.indices.create.call_args.kwargs
    mappings = kwargs["mappings"]
    assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2
    assert kwargs["settings"] == {"index.vector": True}
    huawei_module.redis_client.set.assert_called_once()
def test_huawei_factory_branches(huawei_module, monkeypatch):
    """The factory reuses the stored index name or derives one from the dataset id."""
    factory = huawei_module.HuaweiCloudVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200")
    monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user")
    monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass")
    with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Index names are normalized to lower case in both branches.
    assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection"
    # The derived name is persisted onto the dataset's index struct.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,412 @@
import importlib
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_iris_module():
iris = types.ModuleType("iris")
def connect(**_kwargs):
conn = MagicMock()
conn.cursor.return_value = MagicMock()
return conn
iris.connect = MagicMock(side_effect=connect)
return iris
@pytest.fixture
def iris_module(monkeypatch):
    """Import the IRIS vector module against the fake driver and reset the pool."""
    monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
    import core.rag.datasource.vdb.iris.iris_vector as module
    reloaded = importlib.reload(module)
    # Clear the module-level singleton so each test starts without a cached pool.
    reloaded._pool_instance = None
    return reloaded
def _config(module, **overrides):
    """Build an IrisVectorConfig from sane defaults, applying any overrides."""
    values = {
        "IRIS_HOST": "localhost",
        "IRIS_SUPER_SERVER_PORT": 1972,
        "IRIS_USER": "user",
        "IRIS_PASSWORD": "pass",
        "IRIS_DATABASE": "db",
        "IRIS_SCHEMA": "schema",
        "IRIS_CONNECTION_URL": "url",
        "IRIS_MIN_CONNECTION": 1,
        "IRIS_MAX_CONNECTION": 2,
        "IRIS_TEXT_INDEX": True,
        "IRIS_TEXT_INDEX_LANGUAGE": "en",
    }
    values.update(overrides)
    return module.IrisVectorConfig.model_validate(values)
def test_get_iris_pool_singleton(iris_module):
    """get_iris_pool builds the pool once and returns the cached instance after."""
    iris_module._pool_instance = None
    config = _config(iris_module)
    with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls:
        first = iris_module.get_iris_pool(config)
        second = iris_module.get_iris_pool(config)
    assert first == "pool"
    assert second == "pool"
    # The constructor must only have run for the first call.
    pool_cls.assert_called_once_with(config)
@pytest.fixture
def pool_with_min_max(iris_module):
    """Pool built with min=2/max=3 and a stubbed connection factory."""
    cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3)
    with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn:
        pool = iris_module.IrisConnectionPool(cfg)
        # Yield the factory mock too so tests can count eager connections.
        yield pool, create_conn
def test_pool_initialization_respects_min_max(pool_with_min_max):
    """The pool eagerly opens exactly ``min`` (2) connections at construction."""
    pool, create_conn = pool_with_min_max
    assert create_conn.call_count == 2
    assert len(pool._pool) == 2
@pytest.fixture
def pool_for_get_connection(iris_module):
    """Real pool (min=2/max=3) used by the get_connection tests."""
    cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3)
    pool = iris_module.IrisConnectionPool(cfg)
    return pool
def test_get_connection_returns_existing_and_increments(pool_for_get_connection):
    """An idle connection is handed out and the in-use counter goes up by one."""
    pool = pool_for_get_connection
    idle_conn = MagicMock()
    pool._in_use = 0
    pool._pool = [idle_conn]
    handed_out = pool.get_connection()
    assert handed_out is idle_conn
    assert pool._in_use == 1
def test_get_connection_creates_new_when_empty(pool_for_get_connection):
    """With the idle list empty but capacity left, a fresh connection is created."""
    pool = pool_for_get_connection
    pool._create_connection = MagicMock(return_value="new-conn")
    pool._in_use = 0
    pool._pool = []
    fresh = pool.get_connection()
    assert fresh == "new-conn"
def test_get_connection_raises_when_exhausted(pool_for_get_connection):
    """An empty pool already at max capacity refuses to hand out connections."""
    exhausted_pool = pool_for_get_connection
    exhausted_pool._pool = []
    exhausted_pool._in_use = exhausted_pool._max_size
    with pytest.raises(RuntimeError, match="exhausted"):
        exhausted_pool.get_connection()
@pytest.fixture
def pool_for_return_connection(iris_module):
    """Pool whose eager initialization is bypassed, for return_connection tests."""
    cfg = _config(iris_module)
    with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
        pool = iris_module.IrisConnectionPool(cfg)
    return pool
def test_return_connection_adds_healthy(pool_for_return_connection):
    """A connection that passes the health probe goes back into the idle pool."""
    pool = pool_for_return_connection
    pool._in_use = 1
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    pool.return_connection(conn)
    assert pool._pool[-1] is conn
    assert pool._in_use == 0
def test_return_connection_replaces_bad(pool_for_return_connection):
    """A connection failing its health probe is closed and replaced with a new one."""
    pool = pool_for_return_connection
    pool._in_use = 1
    bad_conn = MagicMock()
    bad_cursor = MagicMock()
    # The health-check query blows up, marking the connection as broken.
    bad_cursor.execute.side_effect = OSError("bad")
    bad_conn.cursor.return_value = bad_cursor
    replacement = MagicMock()
    pool._create_connection = MagicMock(return_value=replacement)
    pool.return_connection(bad_conn)
    bad_conn.close.assert_called_once()
    assert pool._pool[-1] is replacement
    assert pool._in_use == 0
def test_return_connection_ignores_none(pool_for_return_connection):
    """Returning ``None`` must not grow the idle pool."""
    pool = pool_for_return_connection
    size_before = len(pool._pool)
    pool.return_connection(None)
    assert len(pool._pool) == size_before
@pytest.fixture
def pool_for_schema_and_close(iris_module):
    """Pool with one mock connection/cursor pre-seeded, for schema/close tests."""
    cfg = _config(iris_module)
    with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
        pool = iris_module.IrisConnectionPool(cfg)
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    pool._pool = [conn]
    return pool, conn, cursor
def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close):
    """A schema already marked initialized triggers no SQL at all."""
    pool, _conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = {"cached_schema"}
    pool.ensure_schema_exists("cached_schema")
    cursor.execute.assert_not_called()
def test_ensure_schema_exists_creates_new(pool_for_schema_and_close):
    """An unknown schema is created, committed, and remembered in the cache."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    # The existence query reports no such schema.
    cursor.fetchone.return_value = (0,)
    pool.ensure_schema_exists("new_schema")
    assert "new_schema" in pool._schemas_initialized
    assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list)
    conn.commit.assert_called_once()
def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close):
    """A schema already present in the database is accepted without committing DDL."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    # The existence query reports the schema is already there.
    cursor.fetchone.return_value = (1,)
    pool.ensure_schema_exists("existing_schema")
    conn.commit.assert_not_called()
def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close):
    """A failure while probing/creating the schema rolls the transaction back."""
    pool, conn, cursor = pool_for_schema_and_close
    pool._schemas_initialized = set()
    cursor.execute.side_effect = RuntimeError("schema failure")
    with pytest.raises(RuntimeError, match="schema failure"):
        pool.ensure_schema_exists("broken_schema")
    conn.rollback.assert_called()
def test_close_all_closes_and_resets(iris_module):
    """close_all closes every connection (tolerating errors) and resets all state."""
    cfg = _config(iris_module)
    with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None):
        pool = iris_module.IrisConnectionPool(cfg)
    conn = MagicMock()
    conn_2 = MagicMock()
    # A close() failure on one connection must not abort the shutdown.
    conn_2.close.side_effect = OSError("close fail")
    pool._pool = [conn, conn_2]
    pool._schemas_initialized = {"x"}
    pool.close_all()
    assert pool._pool == []
    assert pool._in_use == 0
    assert pool._schemas_initialized == set()
def test_iris_vector_init_get_cursor_and_create(iris_module):
    """Covers construction, the _get_cursor commit/rollback paths, and create()."""
    pool = MagicMock()
    pool.get_connection.return_value = MagicMock()
    with patch.object(iris_module, "get_iris_pool", return_value=pool):
        vector = iris_module.IrisVector("collection", _config(iris_module))
    assert vector.table_name == "EMBEDDING_COLLECTION"
    assert vector.schema == "schema"
    assert vector.get_type() == iris_module.VectorType.IRIS
    # Happy path: the cursor context commits and returns the connection.
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    vector.pool.get_connection.return_value = conn
    with vector._get_cursor() as got_cursor:
        assert got_cursor is cursor
    conn.commit.assert_called_once()
    vector.pool.return_connection.assert_called_with(conn)
    # Failure path: an exception inside the context triggers a rollback.
    conn = MagicMock()
    cursor = MagicMock()
    conn.cursor.return_value = cursor
    vector.pool.get_connection.return_value = conn
    with pytest.raises(RuntimeError, match="boom"):
        with vector._get_cursor():
            raise RuntimeError("boom")
    conn.rollback.assert_called_once()
    # create() sizes the collection from the embedding dim and returns the ids.
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["id-1"])
    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"]
    vector._create_collection.assert_called_once_with(2)
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
    """add_texts, text_exists, deletes, and vector-search row processing."""
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", _config(iris_module))
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # Deterministic id for the document without metadata.
    monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id")
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        SimpleNamespace(page_content="b", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "generated-id"]
    assert cursor.execute.call_count == 2
    cursor.fetchone.return_value = (1,)
    assert vector.text_exists("id-1") is True
    cursor.fetchone.return_value = None
    assert vector.text_exists("id-2") is False
    # A database error is swallowed and reported as "does not exist".
    vector._get_cursor = MagicMock(side_effect=RuntimeError("db down"))
    assert vector.text_exists("id-3") is False
    vector._get_cursor = _cursor_ctx
    vector.delete_by_ids([])
    before = cursor.execute.call_count
    # A non-empty id list maps to exactly one DELETE statement.
    vector.delete_by_ids(["id-1", "id-2"])
    assert cursor.execute.call_count == before + 1
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert "meta LIKE" in cursor.execute.call_args.args[0]
    # Rows are (id, text, meta_json, score); the short ("id-x",) row exercises
    # tolerance to malformed rows, and 0.2 falls below the 0.5 threshold.
    cursor.fetchall.return_value = [
        ("id-1", "text-1", '{"document_id":"d-1"}', 0.9),
        ("id-2", "text-2", '{"document_id":"d-2"}', 0.2),
        ("id-x",),
    ]
    docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
    """Full-text search via text index, its rank-failure fallback, and LIKE mode."""
    cfg = _config(iris_module, IRIS_TEXT_INDEX=True)
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", cfg)
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    cursor.execute.side_effect = None
    # A None score defaults to 0.0 in the resulting Document metadata.
    cursor.fetchall.return_value = [
        ("id-1", "text-1", '{"document_id":"d-1"}', 0.7),
        ("id-2", "text-2", "{}", None),
    ]
    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 2
    assert docs[0].metadata["score"] == pytest.approx(0.7)
    assert docs[1].metadata["score"] == pytest.approx(0.0)
    # When the ranked query fails, a second (fallback) query is issued.
    cursor.reset_mock()
    cursor.execute.side_effect = [RuntimeError("rank failed"), None]
    cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)]
    docs = vector.search_by_full_text("query", top_k=1)
    assert len(docs) == 1
    assert cursor.execute.call_count == 2
    # Without a text index the search falls back to a LIKE query; the helper
    # stub escapes '%' so the literal "100%" cannot act as a wildcard.
    cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False)
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector_like = iris_module.IrisVector("collection", cfg_like)
    vector_like._get_cursor = _cursor_ctx
    fake_libs = types.ModuleType("libs")
    fake_helper = types.ModuleType("libs.helper")
    fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%")
    monkeypatch.setitem(sys.modules, "libs", fake_libs)
    monkeypatch.setitem(sys.modules, "libs.helper", fake_helper)
    cursor.reset_mock()
    cursor.execute.side_effect = None
    cursor.fetchall.return_value = []
    assert vector_like.search_by_full_text("100%", top_k=1) == []
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch):
    """Exercise table drop, cached/uncached collection creation, and factory wiring."""
    with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
        vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True))
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    vector.delete()
    assert "DROP TABLE" in cursor.execute.call_args.args[0]
    # Redis lock/flag plumbing used by _create_collection.
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(iris_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1))
    # Cache hit: no additional execute beyond the earlier delete() call.
    vector._create_collection(2)
    cursor.execute.assert_called_once()
    cursor.reset_mock()
    # Cache miss with text index enabled: table + index statements (3 executes).
    monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None))
    vector.pool.ensure_schema_exists = MagicMock()
    vector._create_collection(3)
    assert cursor.execute.call_count == 3
    iris_module.redis_client.set.assert_called_once()
    cursor.reset_mock()
    # Without the text index one fewer statement is issued.
    vector.config.IRIS_TEXT_INDEX = False
    vector._create_collection(3)
    assert cursor.execute.call_count == 2
    factory = iris_module.IrisVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url")
    monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True)
    monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en")
    with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Existing index struct reuses its collection name; otherwise it is generated.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,394 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_opensearch_modules():
opensearchpy = types.ModuleType("opensearchpy")
opensearch_helpers = types.ModuleType("opensearchpy.helpers")
class BulkIndexError(Exception):
def __init__(self, errors):
super().__init__("bulk error")
self.errors = errors
class OpenSearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.indices = SimpleNamespace(
refresh=MagicMock(),
exists=MagicMock(return_value=False),
delete=MagicMock(),
create=MagicMock(),
)
self.bulk = MagicMock(return_value={"errors": False, "items": []})
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.delete_by_query = MagicMock()
self.get = MagicMock(return_value={"_id": "id"})
self.exists = MagicMock(return_value=True)
opensearch_helpers.BulkIndexError = BulkIndexError
opensearch_helpers.bulk = MagicMock()
opensearchpy.OpenSearch = OpenSearch
opensearchpy.helpers = opensearch_helpers
return {
"opensearchpy": opensearchpy,
"opensearchpy.helpers": opensearch_helpers,
}
@pytest.fixture
def lindorm_module(monkeypatch):
    """Load the lindorm vector module against the fake opensearchpy stubs."""
    fakes = _build_fake_opensearch_modules()
    for mod_name, fake in fakes.items():
        monkeypatch.setitem(sys.modules, mod_name, fake)
    import core.rag.datasource.vdb.lindorm.lindorm_vector as module

    return importlib.reload(module)
def _config(module):
return module.LindormVectorStoreConfig(
hosts="http://localhost:9200",
username="user",
password="pass",
using_ugc=False,
request_timeout=3.0,
)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("hosts", None, "config URL is required"),
        ("username", None, "config USERNAME is required"),
        ("password", None, "config PASSWORD is required"),
    ],
)
def test_lindorm_config_validation(lindorm_module, field, value, message):
    """Each required field raises a descriptive ValidationError when set to None."""
    values = _config(lindorm_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        lindorm_module.LindormVectorStoreConfig.model_validate(values)
def test_to_opensearch_params_and_init(lindorm_module):
    """Config maps to opensearch kwargs; names/routing lower-cased; UGC needs routing."""
    cfg = _config(lindorm_module)
    params = cfg.to_opensearch_params()
    assert params["hosts"] == "http://localhost:9200"
    assert params["http_auth"] == ("user", "pass")
    vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False)
    # Collection name is normalized to lower case.
    assert vector._collection_name == "collection"
    assert vector.get_type() == lindorm_module.VectorType.LINDORM
    # UGC mode requires an explicit routing value.
    with pytest.raises(ValueError, match="routing_value"):
        lindorm_module.LindormVectorStore("c", cfg, using_ugc=True)
    vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE")
    assert vector_ugc._routing == "route"
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch):
    """create() delegates to collection setup + add_texts; bulk writes batch with routing."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1]])
    vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}])
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    # Fresh store: exercise the real add_texts (batch_size=2 over 3 docs -> 2 bulk calls).
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock())
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
        Document(page_content="c", metadata={"doc_id": "id-3"}),
    ]
    embeddings = [[0.1], [0.2], [0.3]]
    vector.add_texts(docs, embeddings, batch_size=2, timeout=9)
    assert vector._client.bulk.call_count == 2
    # Bulk actions interleave index headers and documents; both carry the routing.
    actions = vector._client.bulk.call_args_list[0].args[0]
    assert actions[0]["index"]["routing"] == "route"
    assert actions[1][lindorm_module.ROUTING_FIELD] == "route"
    vector.refresh()
    vector._client.indices.refresh.assert_called_once_with(index="collection")
def test_add_texts_error_paths(lindorm_module):
    """Bulk item errors and bulk exceptions both surface as retry failures."""
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    # A bulk response flagged with item errors must fail after retries.
    vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]}
    docs = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    with pytest.raises(Exception, match="RetryError"):
        vector.add_texts(docs, [[0.1]], batch_size=1)
    # A bulk call that raises is retried and ultimately fails too.
    vector._client.bulk.side_effect = RuntimeError("bulk failed")
    with pytest.raises(Exception, match="RetryError"):
        vector.add_texts(docs, [[0.1]], batch_size=1)
def test_metadata_lookup_and_delete_by_metadata(lindorm_module):
    """Metadata lookup includes the routing term and drives delete_by_metadata_field."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}}
    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert ids == ["id-1", "id-2"]
    query = vector._client.search.call_args.kwargs["body"]
    must_conditions = query["query"]["bool"]["must"]
    # UGC routing is applied as an extra keyword term condition.
    assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions)
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"])
    # No hits -> nothing to delete.
    vector._client.search.return_value = {"hits": {"hits": []}}
    vector.delete_by_ids.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-2")
    vector.delete_by_ids.assert_not_called()
def test_delete_by_ids_paths(lindorm_module):
    """delete_by_ids: empty input, missing index, existence filtering, bulk-error handling."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    # Empty input is a no-op before any index check.
    vector.delete_by_ids([])
    vector._client.indices.exists.assert_not_called()
    # Missing index: nothing to delete.
    vector._client.indices.exists.return_value = False
    vector.delete_by_ids(["id-1"])
    vector._client.indices.exists.return_value = True
    # Only ids that exist become delete actions (id-2 is filtered out).
    vector._client.exists.side_effect = [True, False]
    lindorm_module.helpers.bulk.reset_mock()
    vector.delete_by_ids(["id-1", "id-2"])
    lindorm_module.helpers.bulk.assert_called_once()
    actions = lindorm_module.helpers.bulk.call_args.args[1]
    assert len(actions) == 1
    assert actions[0]["routing"] == "route"
    lindorm_module.helpers.bulk.reset_mock()
    # BulkIndexError with 404 and non-404 items must not propagate to the caller.
    lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError(
        errors=[
            {"delete": {"status": 404, "_id": "id-404"}},
            {"delete": {"status": 500, "_id": "id-500"}},
        ]
    )
    vector._client.exists.side_effect = [True]
    vector.delete_by_ids(["id-1"])
def test_delete_and_text_exists(lindorm_module):
    """delete() uses delete_by_query in UGC mode, index drop otherwise; text_exists maps errors to False."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector.delete()
    vector._client.delete_by_query.assert_called_once()
    vector._client.indices.refresh.assert_called_once_with(index="collection")
    # Non-UGC mode drops the whole index when it exists.
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    vector._client.indices.exists.return_value = True
    vector.delete()
    vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60})
    vector._client.indices.delete.reset_mock()
    vector._client.indices.exists.return_value = False
    vector.delete()
    vector._client.indices.delete.assert_not_called()
    # text_exists: a successful get means True; any client error means False.
    assert vector.text_exists("id-1") is True
    vector._client.get.side_effect = RuntimeError("missing")
    assert vector.text_exists("id-1") is False
def test_search_by_vector_validation_and_success(lindorm_module):
    """Vector search validates input, applies threshold/doc filter, propagates errors."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    with pytest.raises(ValueError, match="should be a list"):
        vector.search_by_vector("bad")
    with pytest.raises(ValueError, match="should be floats"):
        vector.search_by_vector([0.1, "bad"])
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_score": 0.9,
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-a",
                        lindorm_module.Field.VECTOR: [0.1],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"},
                    },
                },
                {
                    "_score": 0.2,
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-b",
                        lindorm_module.Field.VECTOR: [0.2],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"},
                    },
                },
            ]
        }
    }
    # Only the 0.9 hit clears the 0.5 threshold.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    call_kwargs = vector._client.search.call_args.kwargs
    query = call_kwargs["body"]
    assert "ext" in query
    # The document-id filter lands inside the knn clause's bool filter.
    assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"]
    assert call_kwargs["params"]["routing"] == "route"
    vector._client.search.side_effect = RuntimeError("search failed")
    with pytest.raises(RuntimeError, match="search failed"):
        vector.search_by_vector([0.1])
def test_search_by_full_text_success_and_error(lindorm_module):
    """Full-text search builds a filtered bool query and propagates client errors."""
    vector = lindorm_module.LindormVectorStore(
        "collection", _config(lindorm_module), using_ugc=True, routing_value="route"
    )
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        lindorm_module.Field.CONTENT_KEY: "doc-a",
                        lindorm_module.Field.VECTOR: [0.1],
                        lindorm_module.Field.METADATA_KEY: {"doc_id": "1"},
                    }
                }
            ]
        }
    }
    docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "doc-a"
    # The document-id filter is present on the bool query.
    query = vector._client.search.call_args.kwargs["body"]
    assert query["query"]["bool"]["filter"]
    vector._client.search.side_effect = RuntimeError("full text failed")
    with pytest.raises(RuntimeError, match="full text failed"):
        vector.search_by_full_text("hello")
def test_create_collection_paths(lindorm_module, monkeypatch):
    """create_collection: empty input rejected, redis cache short-circuit, index params honored."""
    vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
    with pytest.raises(ValueError, match="cannot be empty"):
        vector.create_collection([])
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1))
    # Redis cache hit -> no index creation.
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()
    # Cache miss + missing index -> create with the requested method params.
    monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"})
    vector._client.indices.create.assert_called_once()
    body = vector._client.indices.create.call_args.kwargs["body"]
    assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf"
    assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine"
    vector._client.indices.create.reset_mock()
    # Index already exists -> creation skipped.
    vector._client.indices.exists.return_value = True
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()
def test_lindorm_factory_branches(lindorm_module, monkeypatch):
    """Factory picks collection name/routing per UGC flag and existing index structs."""
    factory = lindorm_module.LindormVectorStoreFactory()
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0)
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw")
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2")
    monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={})
    embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3])
    # An unset UGC flag is rejected up front.
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None)
    with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"):
        factory.init_vector(dataset, attributes=[], embeddings=embeddings)
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False)
    # Existing plain dataset: reuse the (lower-cased) class prefix as collection name.
    dataset_existing_plain = SimpleNamespace(
        id="dataset-1",
        index_struct="{}",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False},
    )
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings)
    assert result == "vector"
    assert store_cls.call_args.args[0] == "existing"
    # Existing UGC dataset: index name derives from dimension/index/distance,
    # routing from the class prefix.
    dataset_existing_ugc = SimpleNamespace(
        id="dataset-1",
        index_struct="{}",
        index_struct_dict={
            "vector_store": {"class_prefix": "ROUTING"},
            "using_ugc": True,
            "dimension": 1536,
            "index_type": "hnsw",
            "distance_type": "l2",
        },
    )
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2"
    assert store_cls.call_args.kwargs["routing_value"] == "ROUTING"
    # New UGC dataset: dimension comes from the embed_query probe (3 floats).
    dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={})
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True)
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_new, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2"
    assert store_cls.call_args.kwargs["routing_value"] == "auto_collection"
    assert dataset_new.index_struct is not None
    # New plain dataset: generated name, no routing.
    dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={})
    monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False)
    with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls:
        factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings)
    assert store_cls.call_args.args[0] == "auto_collection"
    assert store_cls.call_args.kwargs["routing_value"] is None

View File

@ -0,0 +1,252 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_mo_vector_modules():
mo_vector = types.ModuleType("mo_vector")
mo_vector.__path__ = []
mo_vector_client = types.ModuleType("mo_vector.client")
class MoVectorClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_full_text_index = MagicMock()
self.insert = MagicMock()
self.get = MagicMock(return_value=[])
self.delete = MagicMock()
self.query_by_metadata = MagicMock(return_value=[])
self.query = MagicMock(return_value=[])
self.full_text_query = MagicMock(return_value=[])
mo_vector_client.MoVectorClient = MoVectorClient
mo_vector.client = mo_vector_client
return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client}
@pytest.fixture
def matrixone_module(monkeypatch):
    """Load the matrixone vector module against the fake mo_vector stubs."""
    fakes = _build_fake_mo_vector_modules()
    for mod_name, fake in fakes.items():
        monkeypatch.setitem(sys.modules, mod_name, fake)
    import core.rag.datasource.vdb.matrixone.matrixone_vector as module

    return importlib.reload(module)
def _valid_config(module):
return module.MatrixoneConfig(
host="localhost",
port=6001,
user="dump",
password="111",
database="dify",
metric="l2",
)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config host is required"),
        ("port", 0, "config port is required"),
        ("user", "", "config user is required"),
        ("password", "", "config password is required"),
        ("database", "", "config database is required"),
    ],
)
def test_matrixone_config_validation(matrixone_module, field, value, message):
    """Each required field raises a descriptive ValidationError when falsy."""
    values = _valid_config(matrixone_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        matrixone_module.MatrixoneConfig.model_validate(values)
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch):
    """On a redis cache miss the client is built, the index created, and the flag set."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())
    vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)
    # The collection name is lower-cased when used as the table name.
    assert client.kwargs["table_name"] == "collection_1"
    client.create_full_text_index.assert_called_once()
    matrixone_module.redis_client.set.assert_called_once()
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch):
    """On a redis cache hit the full-text index is not re-created and the flag not re-set."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())
    vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)
    client.create_full_text_index.assert_not_called()
    matrixone_module.redis_client.set.assert_not_called()
def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module):
    """Decorated methods lazily create the client via _get_client(None, False)."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = None
    fake_client = MagicMock()
    fake_client.get.return_value = [{"id": "seg-1"}]
    vector._get_client = MagicMock(return_value=fake_client)
    exists = vector.text_exists("seg-1")
    assert exists is True
    vector._get_client.assert_called_once_with(None, False)
def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module):
    """Full-text hits parse string metadata, convert distance to score, honor the threshold."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    vector.client.full_text_query.return_value = [
        SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.7),
    ]
    # Only doc-a (score 1 - 0.1 = 0.9) clears the 0.5 threshold.
    docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "doc-a"
    assert docs[0].metadata["doc_id"] == "1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    assert vector.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}}
def test_get_type_and_create_delegate_to_add_texts(matrixone_module):
    """create() builds the client sized by embedding dimension and delegates to add_texts."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    fake_client = MagicMock()
    vector._get_client = MagicMock(return_value=fake_client)
    vector.add_texts = MagicMock(return_value=["seg-1"])
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    result = vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == "matrixone"
    assert result == ["seg-1"]
    # Dimension 2 comes from the embedding length; create_table=True.
    vector._get_client.assert_called_once_with(2, True)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch):
    """An index-creation failure is swallowed and the redis flag stays unset."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock())
    failing_client = MagicMock()
    failing_client.create_full_text_index.side_effect = RuntimeError("boom")
    monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client))
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    client = vector._get_client(dimension=3, create_table=True)
    # The client is still returned even though index creation failed.
    assert client is failing_client
    matrixone_module.redis_client.set.assert_not_called()
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch):
    """ids come from metadata doc_id or uuid4; insert receives aligned lists."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}),
        Document(page_content="b", metadata={"document_id": "d-2"}),
        SimpleNamespace(page_content="c", metadata=None),  # doc without metadata
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # For current prod code, only docs with metadata get ids, so only two ids
    assert ids == ["doc-a", "generated-uuid"]
    vector.client.insert.assert_called_once()
    insert_kwargs = vector.client.insert.call_args.kwargs
    # All lists passed to insert should be the same length
    texts = insert_kwargs["texts"]
    embeddings = insert_kwargs["embeddings"]
    metadatas = insert_kwargs["metadatas"]
    ids_insert = insert_kwargs["ids"]
    assert len(texts) == len(embeddings) == len(metadatas) == len(docs)
    # ids may be shorter than docs for current prod code, but should match number of docs with metadata
    assert ids_insert == ["doc-a", "generated-uuid"]
def test_delete_and_metadata_methods(matrixone_module):
    """delete_by_ids/metadata and delete() all route through client.delete."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    vector.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")]
    # Empty input is a no-op.
    vector.delete_by_ids([])
    vector.client.delete.assert_not_called()
    vector.delete_by_ids(["seg-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    vector.delete()
    assert ids == ["seg-1", "seg-2"]
    # Three delete calls: by ids, by metadata field, and full delete.
    assert vector.client.delete.call_count == 3
def test_search_by_vector_builds_documents(matrixone_module):
    """Vector query results map to Documents and the doc-id filter is forwarded."""
    vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
    vector.client = MagicMock()
    vector.client.query.return_value = [
        SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}),
    ]
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 2
    assert docs[0].page_content == "doc-a"
    assert docs[1].metadata["doc_id"] == "2"
    assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}}
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch):
    """Factory reuses an index-struct collection name or generates one for new datasets."""
    factory = matrixone_module.MatrixoneVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001)
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify")
    monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2")
    with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # Datasets without an index struct have one written back.
    assert dataset_without_index.index_struct is not None

View File

@ -1,18 +1,414 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def test_default_value(): def _build_fake_pymilvus_modules():
pymilvus = types.ModuleType("pymilvus")
pymilvus.__path__ = []
pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client")
pymilvus_orm = types.ModuleType("pymilvus.orm")
pymilvus_orm.__path__ = []
pymilvus_orm_types = types.ModuleType("pymilvus.orm.types")
class MilvusError(Exception):
pass
class MilvusClient:
def __init__(self, **kwargs):
self.init_kwargs = kwargs
self.has_collection = MagicMock(return_value=False)
self.describe_collection = MagicMock(
return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]}
)
self.get_server_version = MagicMock(return_value="2.5.0")
self.insert = MagicMock(return_value=[1])
self.query = MagicMock(return_value=[])
self.delete = MagicMock()
self.drop_collection = MagicMock()
self.search = MagicMock(return_value=[[]])
self.create_collection = MagicMock()
class IndexParams:
def __init__(self):
self.indexes = []
def add_index(self, **kwargs):
self.indexes.append(kwargs)
class DataType:
JSON = "JSON"
VARCHAR = "VARCHAR"
INT64 = "INT64"
SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR"
FLOAT_VECTOR = "FLOAT_VECTOR"
class FieldSchema:
def __init__(self, name, dtype, **kwargs):
self.name = name
self.dtype = dtype
self.kwargs = kwargs
class CollectionSchema:
def __init__(self, fields):
self.fields = fields
self.functions = []
def add_function(self, func):
self.functions.append(func)
class FunctionType:
BM25 = "BM25"
class Function:
def __init__(self, **kwargs):
self.kwargs = kwargs
def infer_dtype_bydata(_value):
return DataType.FLOAT_VECTOR
pymilvus.MilvusException = MilvusError
pymilvus.MilvusClient = MilvusClient
pymilvus.IndexParams = IndexParams
pymilvus.CollectionSchema = CollectionSchema
pymilvus.DataType = DataType
pymilvus.FieldSchema = FieldSchema
pymilvus.Function = Function
pymilvus.FunctionType = FunctionType
pymilvus_milvus_client.IndexParams = IndexParams
pymilvus_orm.types = pymilvus_orm_types
pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata
# Attach submodules for dotted imports
pymilvus.milvus_client = pymilvus_milvus_client
pymilvus.orm = pymilvus_orm
return {
"pymilvus": pymilvus,
"pymilvus.milvus_client": pymilvus_milvus_client,
"pymilvus.orm": pymilvus_orm,
"pymilvus.orm.types": pymilvus_orm_types,
}
@pytest.fixture
def milvus_module(monkeypatch):
    """Load the milvus vector module against the fake pymilvus stubs."""
    fakes = _build_fake_pymilvus_modules()
    for mod_name, fake in fakes.items():
        monkeypatch.setitem(sys.modules, mod_name, fake)
    import core.rag.datasource.vdb.milvus.milvus_vector as module

    return importlib.reload(module)
def _config(module, **overrides):
values = {
"uri": "http://localhost:19530",
"user": "root",
"password": "Milvus",
"database": "default",
"enable_hybrid_search": False,
"analyzer_params": None,
}
values.update(overrides)
return module.MilvusConfig.model_validate(values)
def test_config_validation_and_defaults(milvus_module):
valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"} valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"}
for key in valid_config: for key in valid_config:
config = valid_config.copy() config = valid_config.copy()
del config[key] del config[key]
with pytest.raises(ValidationError) as e: with pytest.raises(ValidationError) as e:
MilvusConfig.model_validate(config) milvus_module.MilvusConfig.model_validate(config)
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
config = MilvusConfig.model_validate(valid_config) config = milvus_module.MilvusConfig.model_validate(valid_config)
assert config.database == "default" assert config.database == "default"
token_config = milvus_module.MilvusConfig.model_validate(
{"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"}
)
assert token_config.token == "token-value"
def test_config_to_milvus_params(milvus_module):
    """to_milvus_params maps config fields to MilvusClient keyword arguments."""
    config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}')
    params = config.to_milvus_params()
    assert params["uri"] == "http://localhost:19530"
    assert params["db_name"] == "default"
    assert params["analyzer_params"] == '{"tokenizer":"standard"}'
def test_init_client_supports_token_and_user_password(milvus_module):
    """_init_client passes token auth or user/password through to MilvusClient."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    token_client = vector._init_client(
        milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"})
    )
    assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"}
    user_client = vector._init_client(_config(milvus_module))
    assert user_client.init_kwargs["uri"] == "http://localhost:19530"
    assert user_client.init_kwargs["user"] == "root"
    assert user_client.init_kwargs["password"] == "Milvus"
def test_init_loads_fields_when_collection_exists(milvus_module):
    """__init__ loads field names from an existing collection, excluding the id field."""
    client = milvus_module.MilvusClient(uri="http://localhost:19530")
    client.has_collection.return_value = True
    client.describe_collection.return_value = {
        "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}]
    }
    with patch.object(milvus_module.MilvusVector, "_init_client", return_value=client):
        with patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False):
            vector = milvus_module.MilvusVector("collection_1", _config(milvus_module))
    # "id" is the primary key and must not appear in the cached field list.
    assert "id" not in vector._fields
    assert "content" in vector._fields
def test_load_collection_fields_from_argument_and_remote(milvus_module):
    """_load_collection_fields uses an explicit field list or falls back to describe_collection."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._client = MagicMock()
    vector._collection_name = "collection_1"
    vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]}
    # Explicit list: only "id" is filtered out.
    vector._load_collection_fields(["id", "metadata"])
    assert vector._fields == ["metadata"]
    # No argument: fields are fetched from the (mocked) server.
    vector._load_collection_fields()
    assert vector._fields == ["content"]
def test_check_hybrid_search_support_branches(milvus_module):
    """Hybrid search requires the config flag plus Zilliz Cloud or server >= 2.5."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._client = MagicMock()
    # Disabled in config: always False, regardless of server version.
    vector._client_config = SimpleNamespace(enable_hybrid_search=False)
    assert vector._check_hybrid_search_support() is False
    vector._client_config = SimpleNamespace(enable_hybrid_search=True)
    vector._client.get_server_version.return_value = "Zilliz Cloud 2.4"
    assert vector._check_hybrid_search_support() is True
    vector._client.get_server_version.return_value = "2.5.1"
    assert vector._check_hybrid_search_support() is True
    vector._client.get_server_version.return_value = "2.4.9"
    assert vector._check_hybrid_search_support() is False
    # Version probe failures degrade gracefully to "not supported".
    vector._client.get_server_version.side_effect = RuntimeError("boom")
    assert vector._check_hybrid_search_support() is False
def test_get_type_and_create_delegate(milvus_module):
    """create() delegates to create_collection and add_texts; missing metadata becomes {}."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [SimpleNamespace(page_content="hello", metadata=None)]
    vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == "milvus"
    vector.create_collection.assert_called_once()
    create_args = vector.create_collection.call_args.args
    assert create_args[0] == [[0.1, 0.2]]
    # None metadata is normalized to an empty dict before collection creation.
    assert create_args[1] == [{}]
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_batches_and_raises_milvus_exception(milvus_module):
    """add_texts splits inserts into batches and propagates MilvusException on failure."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.insert.side_effect = [["id-1"], ["id-2"]]
    # 1001 documents exceed one batch, so two insert calls are expected.
    docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)]
    embeddings = [[0.1, 0.2] for _ in range(1001)]
    ids = vector.add_texts(docs, embeddings)
    assert ids == ["id-1", "id-2"]
    assert vector._client.insert.call_count == 2
    vector._client.insert.side_effect = milvus_module.MilvusException("insert failed")
    with pytest.raises(milvus_module.MilvusException):
        vector.add_texts([Document(page_content="x", metadata={})], [[0.1]])
def test_get_ids_and_delete_methods(milvus_module):
    """Lookup-by-metadata and the delete helpers resolve primary keys then delete by pk."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.query.return_value = [{"id": 1}, {"id": 2}]
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2]
    # Empty query result yields None, not an empty list.
    vector._client.query.return_value = []
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
    vector._client.has_collection.return_value = True
    vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102])
    vector._client.delete.reset_mock()
    vector._client.query.return_value = [{"id": 11}, {"id": 12}]
    vector.delete_by_ids(["doc-a", "doc-b"])
    vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12])
    vector._client.has_collection.return_value = True
    vector.delete()
    vector._client.drop_collection.assert_called_once_with("collection_1", None)
def test_text_exists_and_field_exists(milvus_module):
    """text_exists requires the collection and a matching row; field_exists checks the cache."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._fields = ["content", "metadata"]
    vector._client = MagicMock()
    # Missing collection short-circuits to False without querying.
    vector._client.has_collection.return_value = False
    assert vector.text_exists("doc-1") is False
    vector._client.has_collection.return_value = True
    vector._client.query.return_value = [{"id": 1}]
    assert vector.text_exists("doc-1") is True
    vector._client.query.return_value = []
    assert vector.text_exists("doc-1") is False
    assert vector.field_exists("content") is True
    assert vector.field_exists("unknown") is False
def test_process_search_results_and_search_methods(milvus_module):
    """Covers score-threshold filtering, vector search filters and full-text search gating."""
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._fields = ["content", "metadata", "sparse_vector"]
    # Only the 0.9-distance hit survives a 0.5 score threshold.
    processed = vector._process_search_results(
        [
            [
                {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9},
                {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2},
            ]
        ],
        [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY],
        score_threshold=0.5,
    )
    assert len(processed) == 1
    assert processed[0].metadata["score"] == 0.9
    vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]]
    vector._process_search_results = MagicMock(return_value=["doc"])
    docs = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1)
    assert docs == ["doc"]
    # document_ids_filter is translated into a Milvus metadata filter expression.
    assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]'
    # Full-text search returns [] when hybrid search is off or no sparse vector field exists.
    vector._hybrid_search_enabled = False
    assert vector.search_by_full_text("query") == []
    vector._hybrid_search_enabled = True
    vector._fields = []
    assert vector.search_by_full_text("query") == []
    vector._fields = [milvus_module.Field.SPARSE_VECTOR]
    vector._process_search_results = MagicMock(return_value=["full-text-doc"])
    full_text_docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2)
    assert full_text_docs == ["full-text-doc"]
    assert "document_id" in vector._client.search.call_args.kwargs["filter"]
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch):
    """create_collection skips work on a redis cache hit or when the collection exists."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock())
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._consistency_level = "Session"
    vector._client_config = _config(milvus_module)
    vector._hybrid_search_enabled = False
    vector._client = MagicMock()
    # Cache hit: nothing is created.
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"})
    vector._client.create_collection.assert_not_called()
    # Cache miss but collection already exists: only the cache flag is written.
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.has_collection.return_value = True
    vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"})
    milvus_module.redis_client.set.assert_called()
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch):
    """With hybrid search on, the schema gains a sparse-vector field, a BM25 function and two indexes."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock())
    vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
    vector._collection_name = "collection_1"
    vector._consistency_level = "Session"
    vector._client = MagicMock()
    vector._client.has_collection.return_value = False
    vector._load_collection_fields = MagicMock()
    vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}')
    vector._hybrid_search_enabled = True
    vector.create_collection(
        embeddings=[[0.1, 0.2]],
        metadatas=[{"doc_id": "1"}],
        index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}},
    )
    call_kwargs = vector._client.create_collection.call_args.kwargs
    schema = call_kwargs["schema"]
    index_params_obj = call_kwargs["index_params"]
    field_names = [f.name for f in schema.fields]
    assert milvus_module.Field.SPARSE_VECTOR in field_names
    assert len(schema.functions) == 1
    assert len(index_params_obj.indexes) == 2
    assert call_kwargs["consistency_level"] == "Session"
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch):
    """The factory reuses a stored collection name or generates one for new datasets."""
    factory = milvus_module.MilvusVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default")
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True)
    monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}')
    with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A new dataset gets its index struct persisted by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,230 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_clickhouse_connect_module():
clickhouse_connect = types.ModuleType("clickhouse_connect")
class QueryResult:
def __init__(self, rows=None, named_rows=None):
self.row_count = len(rows or [])
self.result_rows = rows or []
self._named_rows = named_rows or []
def named_results(self):
return self._named_rows
class Client:
def __init__(self):
self.command = MagicMock()
self.query = MagicMock(return_value=QueryResult())
client = Client()
def get_client(**_kwargs):
return client
clickhouse_connect.get_client = get_client
clickhouse_connect.QueryResult = QueryResult
clickhouse_connect._fake_client = client
return clickhouse_connect
@pytest.fixture
def myscale_module(monkeypatch):
    """Import the MyScale vector module with clickhouse_connect replaced by a fake."""
    fake_module = _build_fake_clickhouse_connect_module()
    monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
    import core.rag.datasource.vdb.myscale.myscale_vector as module

    # Reload so the module binds to the fake clickhouse_connect for this test.
    return importlib.reload(module)
def _config(module):
    """Build a minimal MyScaleConfig pointing at a local test server."""
    return module.MyScaleConfig(
        host="localhost",
        port=8123,
        user="default",
        password="",
        database="dify",
        fts_params="",
    )
def test_escape_str_replaces_backslash_and_quote(myscale_module):
    """escape_str neutralizes backslashes and single quotes with spaces."""
    escaped = myscale_module.MyScaleVector.escape_str(r"text\with'special")
    assert escaped == "text with special"
def test_search_raises_for_invalid_top_k(myscale_module):
    """_search rejects a non-positive top_k before touching the client."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0)
def test_search_builds_where_clause_for_cosine_threshold(myscale_module):
    """A score threshold of 0.2 becomes a cosine distance cutoff of 0.8 in SQL."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.query.return_value = myscale_module.get_client().query.return_value.__class__(
        named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}]
    )
    docs = vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2)
    assert len(docs) == 1
    sql = vector._client.query.call_args.args[0]
    assert "WHERE dist < 0.8" in sql
def test_delete_by_ids_short_circuits_on_empty_list(myscale_module):
    """delete_by_ids with an empty list issues no SQL."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.command.reset_mock()
    vector.delete_by_ids([])
    vector._client.command.assert_not_called()
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch):
    """The MyScale factory lower-cases collection names, stored or generated."""
    factory = myscale_module.MyScaleVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123)
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify")
    monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "")
    with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    assert dataset_without_index.index_struct is not None
def test_init_and_get_type_set_expected_defaults(myscale_module):
    """__init__ enables the experimental object type and defaults to ascending order."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    assert vector.get_type() == "myscale"
    assert vector._vec_order == myscale_module.SortOrder.ASC
    vector._client.command.assert_called_with("SET allow_experimental_object_type=1")
def test_create_calls_create_collection_and_add_texts(myscale_module):
    """create() sizes the table from the embedding dimension then inserts the docs."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["seg-1"])
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    result = vector.create(docs, [[0.1, 0.2]])
    assert result == ["seg-1"]
    # Dimension 2 comes from the length of the first embedding.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once()
def test_create_collection_builds_expected_sql(myscale_module):
    """The CREATE TABLE SQL embeds the dimension constraint and FTS index params."""
    config = myscale_module.MyScaleConfig(
        host="localhost",
        port=8123,
        user="default",
        password="",
        database="dify",
        fts_params="tokenizer=unicode",
    )
    vector = myscale_module.MyScaleVector("collection_1", config)
    vector._client.command.reset_mock()
    vector._create_collection(3)
    # One command for CREATE DATABASE, one for CREATE TABLE.
    assert vector._client.command.call_count == 2
    sql = vector._client.command.call_args_list[1].args[0]
    assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in sql
    assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in sql
    assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
    """add_texts escapes text, falls back to generated UUIDs, and emits one INSERT."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}),
        Document(page_content="text-2", metadata={"document_id": "d-2"}),
        # Docs without metadata are skipped entirely.
        SimpleNamespace(page_content="text-3", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    assert ids == ["doc-a", "generated-uuid"]
    sql = vector._client.command.call_args.args[0]
    assert "INSERT INTO dify.collection_1" in sql
    # Quote and backslash in page content are replaced by spaces.
    assert "te xt 1" in sql
def test_text_exists_and_metadata_operations(myscale_module):
    """Existence checks read query rows; the delete helpers issue SQL commands."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)])
    assert vector.text_exists("id-1") is True
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert vector._client.command.call_count >= 2
def test_search_delegation_methods(myscale_module):
    """Both public search entry points delegate to the shared _search helper."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._search = MagicMock(return_value=["result"])
    result_vector = vector.search_by_vector([0.1, 0.2], top_k=2)
    result_text = vector.search_by_full_text("hello", top_k=2)
    assert result_vector == ["result"]
    assert result_text == ["result"]
    assert vector._search.call_count == 2
def test_search_with_document_filter_and_exception(myscale_module):
    """_search applies the document-id filter in SQL and swallows query errors as []."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.query.return_value = SimpleNamespace(
        named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}]
    )
    docs = vector._search(
        "distance(vector, [0.1])",
        myscale_module.SortOrder.ASC,
        top_k=2,
        document_ids_filter=["doc-1", "doc-2"],
    )
    assert len(docs) == 1
    sql = vector._client.query.call_args.args[0]
    assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql
    # Query failures are logged and surfaced as an empty result, not raised.
    vector._client.query.side_effect = RuntimeError("boom")
    assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == []
def test_delete_drops_table(myscale_module):
    """delete() drops the backing table with IF EXISTS."""
    vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
    vector._client.command.reset_mock()
    vector.delete()
    vector._client.command.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1")

View File

@ -0,0 +1,553 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from core.rag.models.document import Document
def _build_fake_pyobvector_module():
pyobvector = types.ModuleType("pyobvector")
class VECTOR:
def __init__(self, dim):
self.dim = dim
def l2_distance(*_args, **_kwargs):
return "l2"
def cosine_distance(*_args, **_kwargs):
return "cosine"
def inner_product(*_args, **_kwargs):
return "inner_product"
class ObVecClient:
def __init__(self, **_kwargs):
self.metadata_obj = SimpleNamespace(tables={})
self.engine = MagicMock()
self.check_table_exists = MagicMock(return_value=False)
self.perform_raw_text_sql = MagicMock()
self.prepare_index_params = MagicMock()
self.create_table_with_index_params = MagicMock()
self.refresh_metadata = MagicMock()
self.insert = MagicMock()
self.refresh_index = MagicMock()
self.get = MagicMock()
self.delete = MagicMock()
self.set_ob_hnsw_ef_search = MagicMock()
self.ann_search = MagicMock(return_value=[])
self.drop_table_if_exist = MagicMock()
pyobvector.VECTOR = VECTOR
pyobvector.ObVecClient = ObVecClient
pyobvector.l2_distance = l2_distance
pyobvector.cosine_distance = cosine_distance
pyobvector.inner_product = inner_product
return pyobvector
@pytest.fixture
def oceanbase_module(monkeypatch):
    """Import the OceanBase vector module with pyobvector replaced by a fake."""
    monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())
    import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module

    # Reload so module-level pyobvector names bind to the fake for this test.
    return importlib.reload(module)
def _config(module):
    """Build a minimal OceanBaseVectorConfig with hybrid search on and a small batch size."""
    return module.OceanBaseVectorConfig(
        host="127.0.0.1",
        port=2881,
        user="root",
        password="secret",
        database="test",
        enable_hybrid_search=True,
        batch_size=10,
    )
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OCEANBASE_VECTOR_HOST is required"),
        ("port", 0, "config OCEANBASE_VECTOR_PORT is required"),
        ("user", "", "config OCEANBASE_VECTOR_USER is required"),
        ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"),
    ],
)
def test_oceanbase_config_validation(oceanbase_module, field, value, message):
    """Blanking any required config field raises a ValidationError naming that field."""
    values = _config(oceanbase_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        oceanbase_module.OceanBaseVectorConfig.model_validate(values)
def test_init_rejects_invalid_collection_name(oceanbase_module):
    """Collection names with illegal characters (e.g. hyphens) are rejected at init."""
    with pytest.raises(ValueError, match="Invalid collection name"):
        oceanbase_module.OceanBaseVector("invalid-name", _config(oceanbase_module))
def test_distance_to_score_for_supported_metrics(oceanbase_module):
    """_distance_to_score maps l2, cosine and inner_product distances to scores."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="l2")
    assert vector._distance_to_score(3.0) == pytest.approx(0.25)
    vector._config = SimpleNamespace(metric_type="cosine")
    assert vector._distance_to_score(0.2) == pytest.approx(0.8)
    vector._config = SimpleNamespace(metric_type="inner_product")
    assert vector._distance_to_score(-0.2) == pytest.approx(0.2)
def test_get_distance_func_raises_for_unknown_metric(oceanbase_module):
    """Unknown metric types raise instead of silently defaulting."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(metric_type="manhattan")
    with pytest.raises(ValueError, match="Unsupported metric_type"):
        vector._get_distance_func()
def test_process_search_results_handles_json_and_score_threshold(oceanbase_module):
    """Metadata may be a JSON string, invalid JSON or a dict; low scores are dropped."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    rows = [
        ("doc-1", '{"doc_id":"1"}', 0.9),
        ("doc-2", "not-json", 0.8),
        ("doc-3", {"doc_id": "3"}, 0.3),
    ]
    docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank")
    # The 0.3-score row falls below the threshold and is excluded.
    assert len(docs) == 2
    assert docs[0].metadata["doc_id"] == "1"
    assert docs[0].metadata["rank"] == 0.9
    assert docs[1].metadata["rank"] == 0.8
def test_search_by_vector_validates_document_id_format(oceanbase_module):
    """Document-id filters with unsafe characters are rejected before building SQL."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()
    with pytest.raises(ValueError, match="Invalid document ID format"):
        vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"])
def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module):
    """Full-text search is a no-op when hybrid search is disabled."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._hybrid_search_enabled = False
    vector._collection_name = "collection_1"
    assert vector.search_by_full_text("query") == []
def test_check_hybrid_search_support_uses_version_comment(oceanbase_module):
    """Hybrid search support is decided by parsing the server version comment (>= 4.3.5.1)."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client = MagicMock()
    cursor = MagicMock()
    cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",)
    vector._client.perform_raw_text_sql.return_value = cursor
    assert vector._check_hybrid_search_support() is True
    cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",)
    assert vector._check_hybrid_search_support() is False
def test_init_get_type_and_field_loading(oceanbase_module):
    """__init__ with an existing table loads column names into the field cache."""
    config = _config(oceanbase_module)
    config.enable_hybrid_search = False
    table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")])
    fake_client = oceanbase_module.ObVecClient()
    fake_client.check_table_exists.return_value = True
    fake_client.metadata_obj.tables = {"collection_1": table}
    with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client):
        vector = oceanbase_module.OceanBaseVector("collection_1", config)
    assert vector.get_type() == "oceanbase"
    assert vector.field_exists("text") is True
def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module):
    """Missing table metadata or column errors leave the field cache empty, not raised."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._fields = []
    vector._client = MagicMock()
    vector._client.metadata_obj.tables = {}
    vector._load_collection_fields()
    assert vector._fields == []
    vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))}
    vector._load_collection_fields()
    assert vector._fields == []
def test_create_delegates_to_collection_and_insert(oceanbase_module):
    """create() records the embedding dimension, creates the table, then inserts."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="text", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1, 0.2]])
    assert vector._vec_dim == 2
    vector._create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
    """_create_collection skips work on a redis cache hit or when the table exists."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = False
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    # Cache hit: not even a table-existence check is made.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection()
    vector._client.check_table_exists.assert_not_called()
    # Cache miss but table present: nothing is dropped or recreated.
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.check_table_exists.return_value = True
    vector._create_collection()
    vector.delete.assert_not_called()
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
    """Full creation path: drop stale table, create with index params, refresh metadata."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 3
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    # First raw SQL returns the memory-limit variable row ("30"); later calls succeed.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        None,
        None,
    ]
    index_params = MagicMock()
    vector._client.prepare_index_params.return_value = index_params
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    vector._create_collection()
    vector.delete.assert_called_once()
    vector._client.create_table_with_index_params.assert_called_once()
    index_params.add_index.assert_called_once()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])
    oceanbase_module.redis_client.set.assert_called_once()
def test_create_collection_error_paths(oceanbase_module, monkeypatch):
    """Error branches: missing memory-limit var, failed SET, invalid full-text parser."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    # No rows at all: the memory-limit variable cannot be found.
    vector._client.perform_raw_text_sql.return_value = []
    with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"):
        vector._create_collection()
    # Variable is 0 and the SET statement fails.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "0"]],
        RuntimeError("no privilege"),
    ]
    with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"):
        vector._create_collection()
    # An unsupported full-text parser name is rejected.
    vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]]
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid")
    with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"):
        vector._create_collection()
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
    """Fulltext index failures raise; metadata index failures are tolerated."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock())
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik")
    monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim))
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._vec_dim = 2
    vector._hybrid_search_enabled = True
    vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64)
    vector._client = MagicMock()
    vector._client.check_table_exists.return_value = False
    vector._client.prepare_index_params.return_value = MagicMock()
    vector.delete = MagicMock()
    vector._load_collection_fields = MagicMock()
    # Hybrid on: a fulltext index creation failure is fatal.
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        RuntimeError("fulltext failed"),
    ]
    with pytest.raises(Exception, match="Failed to add fulltext index"):
        vector._create_collection()
    # Hybrid off: a metadata index SQLAlchemyError is swallowed and creation completes.
    vector._hybrid_search_enabled = False
    vector._client.perform_raw_text_sql.side_effect = [
        [[None, None, None, None, None, None, "30"]],
        SQLAlchemyError("metadata index failed"),
    ]
    vector._create_collection()
    vector._client.refresh_metadata.assert_called_once_with(["collection_1"])
def test_check_hybrid_search_support_false_and_exception(oceanbase_module):
    """Disabled config or a failing version probe both report no hybrid support."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._config = SimpleNamespace(enable_hybrid_search=False)
    vector._client = MagicMock()
    assert vector._check_hybrid_search_support() is False
    vector._config = SimpleNamespace(enable_hybrid_search=True)
    vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom")
    assert vector._check_hybrid_search_support() is False
def test_add_texts_batches_refresh_and_exceptions(oceanbase_module):
    """add_texts batches inserts, refreshes the index past the threshold, and wraps errors.

    Three docs with batch_size=2 yield two insert calls plus one refresh; an insert
    failure raises "Failed to insert batch"; a refresh failure (SQLAlchemyError)
    does not propagate to the caller.
    """
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2)
    vector._client = MagicMock()
    vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"])
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
        Document(page_content="c", metadata={"doc_id": "id-3"}),
    ]
    vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # 3 docs / batch_size 2 -> two insert batches; doc count >= threshold triggers refresh.
    assert vector._client.insert.call_count == 2
    vector._client.refresh_index.assert_called_once()
    vector._client.insert.reset_mock()
    vector._client.refresh_index.reset_mock()
    vector._client.insert.side_effect = RuntimeError("insert failed")
    with pytest.raises(Exception, match="Failed to insert batch"):
        vector.add_texts([docs[0]], [[0.1]])
    vector._client.insert.side_effect = None
    vector._client.insert.return_value = None
    # A failing refresh must be tolerated: this call returns without raising.
    vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed")
    vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1)
    vector._get_uuids.return_value = ["id-1"]
    vector.add_texts([docs[0]], [[0.1]])
def test_text_exists_and_delete_by_ids(oceanbase_module):
    """text_exists reflects the row count; delete_by_ids skips empty input and wraps errors."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    vector._client.get.return_value = SimpleNamespace(rowcount=1)
    assert vector.text_exists("id-1") is True
    # A failing lookup is wrapped rather than leaking the raw client error.
    vector._client.get.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to check text existence"):
        vector.text_exists("id-1")
    # Empty id list: no delete issued at all.
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    vector._client.delete.side_effect = None
    vector.delete_by_ids(["id-1"])
    vector._client.delete.assert_called_once()
    # Client failures during delete are wrapped as well.
    vector._client.delete.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete documents"):
        vector.delete_by_ids(["id-1"])
def test_get_ids_and_delete_by_metadata_field(oceanbase_module):
    """get_ids_by_metadata_field reads ids via SQL; delete_by_metadata_field chains into delete_by_ids."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._client = MagicMock()
    execute_result = [("id-1",), ("id-2",)]
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value = execute_result
    vector._client.engine.connect.return_value = conn
    ids = vector.get_ids_by_metadata_field("document_id", "doc-1")
    assert ids == ["id-1", "id-2"]
    # A key with illegal characters is rejected with a wrapped error.
    with pytest.raises(Exception, match="Failed to query documents by metadata field"):
        vector.get_ids_by_metadata_field("bad key!", "doc-1")
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
    # No matching ids: delete_by_ids must not be invoked.
    vector.get_ids_by_metadata_field = MagicMock(return_value=[])
    vector.delete_by_ids.reset_mock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_not_called()
def test_search_by_full_text_paths(oceanbase_module):
    """Full-text search: missing field short-circuits, hits become scored Documents, bad input is wrapped."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hybrid_search_enabled = True
    # Without the fulltext field the search returns no results immediately.
    vector.field_exists = MagicMock(return_value=False)
    assert vector.search_by_full_text("query") == []
    vector.field_exists.return_value = True
    vector._client = MagicMock()
    # Mock a connection whose transaction and context-manager protocol both work.
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = tx
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)]
    vector._client.engine.connect.return_value = conn
    docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.9
    # top_k=0 is invalid and surfaces as a wrapped full-text search error.
    with pytest.raises(Exception, match="Full-text search failed"):
        vector.search_by_full_text("query", top_k=0)
def test_search_by_vector_paths(oceanbase_module):
    """Vector search applies ef_search, delegates result mapping, and validates/wraps errors."""
    vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    vector._collection_name = "collection_1"
    vector._hnsw_ef_search = -1
    vector._config = SimpleNamespace(metric_type="cosine")
    vector._client = MagicMock()
    vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)]
    vector._process_search_results = MagicMock(return_value=["doc"])
    docs = vector.search_by_vector(
        [0.1, 0.2],
        ef_search=10,
        top_k=3,
        score_threshold=0.1,
        document_ids_filter=["good_id"],
    )
    assert docs == ["doc"]
    # ef_search=10 differs from the cached -1, so the session variable is updated.
    vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10)
    # A non-numeric score_threshold is rejected up front.
    with pytest.raises(ValueError, match="Invalid score_threshold parameter"):
        vector.search_by_vector([0.1], score_threshold="x")
    # Client-side failures are wrapped in a generic vector-search error.
    vector._client.ann_search.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Vector search failed"):
        vector.search_by_vector([0.1], score_threshold=0.1)
def test_get_distance_func_and_distance_to_score_errors(oceanbase_module):
    """A known metric maps to its distance function; an unknown metric raises on scoring."""
    ob_vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    ob_vector._config = SimpleNamespace(metric_type="cosine")
    assert ob_vector._get_distance_func() is oceanbase_module.cosine_distance
    # Unrecognized metric types are rejected when converting distance to score.
    ob_vector._config = SimpleNamespace(metric_type="unknown")
    with pytest.raises(ValueError, match="Unsupported metric_type"):
        ob_vector._distance_to_score(0.1)
def test_delete_success_and_exception(oceanbase_module):
    """delete() drops the backing table; a drop failure is wrapped and re-raised."""
    ob_vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector)
    ob_vector._collection_name = "collection_1"
    ob_vector._client = MagicMock()
    ob_vector.delete()
    ob_vector._client.drop_table_if_exist.assert_called_once_with("collection_1")
    # A failing drop surfaces as a wrapped "Failed to delete collection" error.
    ob_vector._client.drop_table_if_exist.side_effect = RuntimeError("boom")
    with pytest.raises(Exception, match="Failed to delete collection"):
        ob_vector.delete()
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
    """The factory lowercases an existing class_prefix or a freshly generated collection name."""
    factory = oceanbase_module.OceanBaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Pin every config value the factory reads so the test is hermetic.
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine")
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10)
    monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000)
    with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Collection names are normalized to lowercase in both code paths.
    assert vector_cls.call_args_list[0].args[0] == "existing_collection"
    assert vector_cls.call_args_list[1].args[0] == "auto_collection"
    # A dataset lacking an index struct has one written back by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,400 @@
import importlib
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_psycopg2_modules():
psycopg2 = types.ModuleType("psycopg2")
psycopg2.__path__ = []
psycopg2_extras = types.ModuleType("psycopg2.extras")
psycopg2_pool = types.ModuleType("psycopg2.pool")
class SimpleConnectionPool:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.getconn = MagicMock()
self.putconn = MagicMock()
psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool
psycopg2_extras.execute_values = MagicMock()
psycopg2.pool = psycopg2_pool
psycopg2.extras = psycopg2_extras
return {
"psycopg2": psycopg2,
"psycopg2.pool": psycopg2_pool,
"psycopg2.extras": psycopg2_extras,
}
@pytest.fixture
def opengauss_module(monkeypatch):
    """Import the opengauss vdb module against stubbed psycopg2 and return it freshly reloaded."""
    for name, module in _build_fake_psycopg2_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.opengauss.opengauss as module
    # Reload so module-level psycopg2 references bind to the stubs above.
    return importlib.reload(module)
def _config(module, *, enable_pq=False):
    """Build a valid OpenGaussConfig; toggle product quantization via *enable_pq*."""
    params = {
        "host": "localhost",
        "port": 6600,
        "user": "postgres",
        "password": "password",
        "database": "dify",
        "min_connection": 1,
        "max_connection": 5,
        "enable_pq": enable_pq,
    }
    return module.OpenGaussConfig(**params)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OPENGAUSS_HOST is required"),
        ("port", 0, "config OPENGAUSS_PORT is required"),
        ("user", "", "config OPENGAUSS_USER is required"),
        ("password", "", "config OPENGAUSS_PASSWORD is required"),
        ("database", "", "config OPENGAUSS_DATABASE is required"),
        ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"),
        ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"),
    ],
)
def test_opengauss_config_validation(opengauss_module, field, value, message):
    """Each required config field raises its field-specific message when empty/zero."""
    values = _config(opengauss_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        opengauss_module.OpenGaussConfig.model_validate(values)
def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module):
    """A min_connection larger than max_connection must fail validation."""
    payload = _config(opengauss_module).model_dump()
    payload.update(min_connection=6, max_connection=5)
    with pytest.raises(ValidationError, match="OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"):
        opengauss_module.OpenGaussConfig.model_validate(payload)
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
    """Constructing OpenGauss derives the table name, reports its type, and keeps the pool."""
    fake_pool = MagicMock()
    monkeypatch.setattr(
        opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=fake_pool)
    )
    instance = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    assert instance.table_name == "embedding_collection_1"
    assert instance.get_type() == "opengauss"
    assert instance.pool is fake_pool
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
    """With enable_pq on, the index DDL includes PQ options and tuning statements."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    # No cache hit (get -> None), so index creation proceeds and marks the cache.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    vector._create_index(1536)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("enable_pq=on" in sql for sql in executed_sql)
    assert any("SET hnsw_earlystop_threshold = 320" in sql for sql in executed_sql)
    opengauss_module.redis_client.set.assert_called_once()
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
    """Dimensions above the supported limit skip index DDL but still mark the redis cache."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # 3072 dims exceed the index limit: no SQL is executed, cache flag is still set.
    vector._create_index(3072)
    cursor.execute.assert_not_called()
    opengauss_module.redis_client.set.assert_called_once()
def test_search_by_vector_validates_top_k(opengauss_module):
    """A non-positive top_k is rejected before any query is issued."""
    instance = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        instance.search_by_vector([0.1, 0.2], top_k=0)
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
    """delete_by_ids([]) returns without ever opening a cursor."""
    monkeypatch.setattr(
        opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=MagicMock())
    )
    instance = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    instance._get_cursor = MagicMock()
    instance.delete_by_ids([])
    instance._get_cursor.assert_not_called()
def test_get_cursor_closes_commits_and_returns_connection(opengauss_module):
    """_get_cursor yields a cursor and, on exit, closes it, commits, and returns the conn."""
    instance = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    fake_pool = MagicMock()
    fake_conn = MagicMock()
    fake_cursor = MagicMock()
    fake_pool.getconn.return_value = fake_conn
    fake_conn.cursor.return_value = fake_cursor
    instance.pool = fake_pool
    with instance._get_cursor() as yielded:
        assert yielded is fake_cursor
    # All cleanup happens on context exit.
    fake_cursor.close.assert_called_once()
    fake_conn.commit.assert_called_once()
    fake_pool.putconn.assert_called_once_with(fake_conn)
def test_create_calls_collection_insert_and_index(opengauss_module):
    """create() wires the embedding dimension into collection and index setup."""
    instance = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    instance._create_collection = MagicMock()
    instance.add_texts = MagicMock()
    instance._create_index = MagicMock()
    documents = [Document(page_content="text", metadata={"doc_id": "seg-1"})]
    embeddings = [[0.1, 0.2]]
    instance.create(documents, embeddings)
    # Dimension 2 comes from the first embedding vector.
    instance._create_collection.assert_called_once_with(2)
    instance.add_texts.assert_called_once_with(documents, embeddings)
    instance._create_index.assert_called_once_with(2)
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
    """A redis cache hit short-circuits _create_index before any SQL or cache write."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    # get -> 1 signals the index was already created by another worker.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    vector._get_cursor = MagicMock()
    vector._create_index(1536)
    vector._get_cursor.assert_not_called()
    opengauss_module.redis_client.set.assert_not_called()
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
    """Without PQ the standard cosine HNSW index DDL is executed."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    vector._create_index(1536)
    sql = [call.args[0] for call in cursor.execute.call_args_list]
    # The index name embeds the metric and table name.
    assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql)
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
    """add_texts bulk-inserts via execute_values and returns ids taken from doc metadata."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cursor = MagicMock()
    opengauss_module.psycopg2.extras.execute_values.reset_mock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # Second "doc" has no metadata, exercising the generated-uuid fallback path.
    docs = [
        Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}),
        SimpleNamespace(page_content="text-2", metadata=None),
    ]
    monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid")
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    # Only the document carrying metadata contributes its doc_id to the result.
    assert ids == ["seg-1"]
    opengauss_module.psycopg2.extras.execute_values.assert_called_once()
def test_text_exists_and_get_by_ids(opengauss_module):
    """text_exists is true when a row is found; get_by_ids maps (meta, text) rows to Documents."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("seg-1",)
    # Iterating the cursor yields (metadata, page_content) pairs.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    assert vector.text_exists("seg-1") is True
    docs = vector.get_by_ids(["seg-1", "seg-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
def test_delete_and_metadata_field_queries(opengauss_module):
    """The deletion helpers emit the expected DELETE/DROP SQL against the table."""
    instance = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    instance.table_name = "embedding_collection_1"
    fake_cursor = MagicMock()

    @contextmanager
    def fake_cursor_ctx():
        yield fake_cursor

    instance._get_cursor = fake_cursor_ctx
    instance.delete_by_ids(["seg-1", "seg-2"])
    instance.delete_by_metadata_field("document_id", "doc-1")
    instance.delete()
    executed = [call.args[0] for call in fake_cursor.execute.call_args_list]
    # One statement per operation: id delete, metadata delete, table drop.
    assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in query for query in executed)
    assert any("meta->>%s = %s" in query for query in executed)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in query for query in executed)
def test_search_by_vector_and_full_text(opengauss_module):
    """Vector search converts distance to score and thresholds; full-text maps rows to Documents."""
    vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows are (metadata, page_content, distance).
    cursor.__iter__.return_value = iter(
        [
            ({"doc_id": "1"}, "text-1", 0.1),
            ({"doc_id": "2"}, "text-2", 0.6),
        ]
    )
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    # Distance 0.1 -> score 0.9 passes the 0.5 threshold; distance 0.6 -> 0.4 is dropped.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)])
    full_docs = vector.search_by_full_text("hello world", top_k=2)
    assert len(full_docs) == 1
    assert full_docs[0].page_content == "full-text"
def test_search_by_full_text_validates_top_k(opengauss_module):
    """A non-positive top_k is rejected for full-text search as well."""
    instance = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss)
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        instance.search_by_full_text("query", top_k=0)
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
    """_create_collection skips DDL on a redis cache hit and creates the table otherwise."""
    pool = MagicMock()
    monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock())
    vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # Cache hit: no DDL executed.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(1536)
    cursor.execute.assert_not_called()
    # Cache miss: table is created and the cache flag is set.
    monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(1536)
    cursor.execute.assert_called_once()
    opengauss_module.redis_client.set.assert_called_once()
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
    """The factory reuses an existing class_prefix or generates a collection name from the dataset id."""
    factory = opengauss_module.OpenGaussFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Pin every config value the factory reads so the test is hermetic.
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_HOST", "localhost")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PORT", 6600)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_USER", "postgres")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PASSWORD", "password")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_DATABASE", "dify")
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MIN_CONNECTION", 1)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MAX_CONNECTION", 5)
    monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_ENABLE_PQ", False)
    with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A dataset lacking an index struct has one written back by the factory.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,360 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_opensearch_modules():
opensearchpy = types.ModuleType("opensearchpy")
opensearchpy_helpers = types.ModuleType("opensearchpy.helpers")
class BulkIndexError(Exception):
def __init__(self, errors):
super().__init__("bulk error")
self.errors = errors
class Urllib3AWSV4SignerAuth:
def __init__(self, credentials, region, service):
self.credentials = credentials
self.region = region
self.service = service
class Urllib3HttpConnection:
pass
class _IndicesClient:
def __init__(self):
self.exists = MagicMock(return_value=False)
self.create = MagicMock()
self.delete = MagicMock()
class OpenSearch:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.indices = _IndicesClient()
self.search = MagicMock(return_value={"hits": {"hits": []}})
self.get = MagicMock()
helpers = SimpleNamespace(bulk=MagicMock())
opensearchpy.OpenSearch = OpenSearch
opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth
opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection
opensearchpy.helpers = helpers
opensearchpy_helpers.BulkIndexError = BulkIndexError
return {
"opensearchpy": opensearchpy,
"opensearchpy.helpers": opensearchpy_helpers,
}
@pytest.fixture
def opensearch_module(monkeypatch):
    """Import the opensearch vector module against stubbed opensearchpy and return it freshly reloaded."""
    for name, module in _build_fake_opensearch_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.opensearch.opensearch_vector as module
    # Reload so module-level opensearchpy references bind to the stubs above.
    return importlib.reload(module)
def _config(module, **overrides):
    """Build a valid OpenSearchConfig, with keyword *overrides* applied on top."""
    payload = {
        "host": "localhost",
        "port": 9200,
        "secure": True,
        "verify_certs": True,
        "auth_method": "basic",
        "user": "admin",
        "password": "secret",
        **overrides,
    }
    return module.OpenSearchConfig.model_validate(payload)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config OPENSEARCH_HOST is required"),
        ("port", 0, "config OPENSEARCH_PORT is required"),
    ],
)
def test_config_validation_required_fields(opensearch_module, field, value, message):
    """host and port must be set; emptying either raises its field-specific message."""
    values = _config(opensearch_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        opensearch_module.OpenSearchConfig.model_validate(values)
def test_config_validation_for_aws_auth_and_https_fields(opensearch_module):
    """aws_managed_iam requires an AWS region; verify_certs=True requires secure=True."""
    values = {
        "host": "localhost",
        "port": 9200,
        "secure": True,
        "verify_certs": True,
        "auth_method": "aws_managed_iam",
        "user": "admin",
        "password": "secret",
    }
    with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"):
        opensearch_module.OpenSearchConfig.model_validate(values)
    values = _config(opensearch_module).model_dump()
    # NOTE(review): these keys use env-style names, unlike the lowercase field
    # names above — presumably the model accepts them as aliases; verify.
    values["OPENSEARCH_SECURE"] = False
    values["OPENSEARCH_VERIFY_CERTS"] = True
    with pytest.raises(ValidationError, match="verify_certs=True requires secure"):
        opensearch_module.OpenSearchConfig.model_validate(values)
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch):
    """create_aws_managed_iam_auth builds a signer from boto3 session credentials."""
    class _Session:
        def get_credentials(self):
            return "creds"
    # Inject a fake boto3 so the config can import it without the real SDK.
    boto3 = types.ModuleType("boto3")
    boto3.Session = _Session
    monkeypatch.setitem(sys.modules, "boto3", boto3)
    config = _config(
        opensearch_module,
        auth_method="aws_managed_iam",
        aws_region="us-east-1",
        aws_service="es",
    )
    auth = config.create_aws_managed_iam_auth()
    # The signer receives the session credentials plus region/service from config.
    assert auth.credentials == "creds"
    assert auth.region == "us-east-1"
    assert auth.service == "es"
def test_to_opensearch_params_supports_basic_and_aws(opensearch_module):
    """to_opensearch_params yields a basic-auth tuple or the AWS IAM signer."""
    basic_params = _config(opensearch_module).to_opensearch_params()
    assert basic_params["http_auth"] == ("admin", "secret")
    aws_config = _config(
        opensearch_module,
        auth_method="aws_managed_iam",
        aws_region="us-west-2",
        aws_service="es",
    )
    # Stub the signer factory so no boto3 session is needed.
    with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"):
        aws_params = aws_config.to_opensearch_params()
    assert aws_params["http_auth"] == "iam-auth"
def test_init_and_create_delegate_calls(opensearch_module):
    """create() delegates to create_collection and add_texts with matching arguments."""
    instance = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module))
    instance.create_collection = MagicMock()
    instance.add_texts = MagicMock()
    documents = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    instance.create(documents, [[0.1, 0.2]])
    assert instance.get_type() == "opensearch"
    # create_collection gets embeddings plus the raw metadata dicts.
    instance.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}])
    instance.add_texts.assert_called_once_with(documents, [[0.1, 0.2]])
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch):
    """Bulk actions carry _id for a regular cluster but omit it for serverless (aoss)."""
    vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es"))
    docs = [
        Document(page_content="a", metadata={"doc_id": "1"}),
        Document(page_content="b", metadata={"doc_id": "2"}),
    ]
    monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id"))
    opensearch_module.helpers.bulk.reset_mock()
    vector.add_texts(docs, [[0.1], [0.2]])
    actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"]
    assert len(actions) == 2
    assert all("_id" in action for action in actions)
    # Serverless collections do not accept explicit document ids.
    vector._client_config.aws_service = "aoss"
    opensearch_module.helpers.bulk.reset_mock()
    vector.add_texts(docs, [[0.3], [0.4]])
    aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"]
    assert all("_id" not in action for action in aoss_actions)
def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module):
    """Metadata id lookup returns hit ids (or None when empty) and feeds delete_by_ids."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}}
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
    # No hits: lookup reports None rather than an empty list.
    vector._client.search.return_value = {"hits": {"hits": []}}
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-1"])
def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module):
    """delete_by_ids skips missing indices, bulk-deletes found ids, and tolerates 404 bulk errors."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    opensearch_module.helpers.bulk.reset_mock()
    # Index missing: no bulk call at all.
    vector._client.indices.exists.return_value = False
    vector.delete_by_ids(["doc-1"])
    opensearch_module.helpers.bulk.assert_not_called()
    vector._client.indices.exists.return_value = True
    # First id resolves to an ES id, the second does not; only resolved ids are bulk-deleted.
    vector.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None])
    vector.delete_by_ids(["doc-1", "doc-2"])
    opensearch_module.helpers.bulk.assert_called_once()
    opensearch_module.helpers.bulk.reset_mock()
    # A BulkIndexError consisting only of 404 deletes must not propagate.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["es-404"])
    opensearch_module.helpers.bulk.side_effect = opensearch_module.BulkIndexError(
        [{"delete": {"status": 404, "_id": "es-404"}}]
    )
    vector.delete_by_ids(["doc-404"])
    assert opensearch_module.helpers.bulk.call_count == 1
    opensearch_module.helpers.bulk.side_effect = None
def test_delete_and_text_exists(opensearch_module):
    """delete() drops the index; text_exists reflects client get success or failure."""
    instance = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    instance.delete()
    instance._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True)
    instance._client.get.return_value = {"_id": "id-1"}
    assert instance.text_exists("id-1") is True
    # Any client error while fetching is treated as "does not exist".
    instance._client.get.side_effect = RuntimeError("not found")
    assert instance.text_exists("id-1") is False
def test_search_by_vector_validates_and_builds_documents(opensearch_module):
    """search_by_vector validates the query vector, thresholds scores, and builds filtered queries."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    # The query vector must be a list of floats.
    with pytest.raises(ValueError, match="query_vector should be a list"):
        vector.search_by_vector("not-a-list")
    with pytest.raises(ValueError, match="should be floats"):
        vector.search_by_vector([0.1, 1])
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        opensearch_module.Field.CONTENT_KEY: "doc-1",
                        opensearch_module.Field.METADATA_KEY: None,
                    },
                    "_score": 0.9,
                },
                {
                    "_source": {
                        opensearch_module.Field.CONTENT_KEY: "doc-2",
                        opensearch_module.Field.METADATA_KEY: {"doc_id": "2"},
                    },
                    "_score": 0.1,
                },
            ]
        }
    }
    # Only the 0.9-scored hit survives the 0.5 threshold; None metadata is tolerated.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].page_content == "doc-1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # A document_ids_filter produces a script_score query body.
    vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"])
    query = vector._client.search.call_args.kwargs["body"]
    assert "script_score" in query["query"]
def test_search_by_vector_reraises_client_error(opensearch_module):
    """Client-level search failures propagate unchanged to the caller."""
    instance = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    instance._client.search.side_effect = RuntimeError("boom")
    with pytest.raises(RuntimeError, match="boom"):
        instance.search_by_vector([0.1, 0.2])
def test_search_by_full_text_and_filters(opensearch_module):
    """Full-text search maps hits to Documents and applies document_id term filters."""
    vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module))
    vector._client.search.return_value = {
        "hits": {
            "hits": [
                {
                    "_source": {
                        opensearch_module.Field.METADATA_KEY: {"doc_id": "1"},
                        opensearch_module.Field.VECTOR: [0.1],
                        opensearch_module.Field.CONTENT_KEY: "matched text",
                    }
                },
            ]
        }
    }
    docs = vector.search_by_full_text("hello", document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].page_content == "matched text"
    # The filter list is translated into a terms clause on metadata.document_id.
    query = vector._client.search.call_args.kwargs["body"]
    assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}]
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch):
    """create_collection skips index creation on a redis cache hit, else creates the index."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock())
    vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module))
    # Cache hit: the index must not be created.
    monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1))
    vector._client.indices.create.reset_mock()
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_not_called()
    # Cache miss with no existing index: create, sized by the embedding length.
    monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.indices.exists.return_value = False
    vector.create_collection([[0.1, 0.2]])
    vector._client.indices.create.assert_called_once()
    index_body = vector._client.indices.create.call_args.kwargs["body"]
    assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2
    # The cache flag is written back so subsequent calls short-circuit.
    opensearch_module.redis_client.set.assert_called()
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch):
    """The factory reuses an existing class_prefix or generates a name, lower-cased."""
    factory = opensearch_module.OpenSearchVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_HOST", "localhost")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PORT", 9200)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_SECURE", True)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_VERIFY_CERTS", True)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AUTH_METHOD", "basic")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_USER", "admin")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PASSWORD", "secret")
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_REGION", None)
    monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_SERVICE", None)
    with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # OpenSearch collection names are normalized to lower case.
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one written back.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,375 @@
import array
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import numpy
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_oracle_modules():
jieba = types.ModuleType("jieba")
jieba_posseg = types.ModuleType("jieba.posseg")
jieba_posseg.cut = MagicMock(return_value=[])
jieba.posseg = jieba_posseg
oracledb = types.ModuleType("oracledb")
oracledb_connection = types.ModuleType("oracledb.connection")
class Connection:
pass
oracledb_connection.Connection = Connection
oracledb.defaults = SimpleNamespace(fetch_lobs=True)
oracledb.DB_TYPE_VECTOR = object()
oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock()))
oracledb.connect = MagicMock()
return {
"jieba": jieba,
"jieba.posseg": jieba_posseg,
"oracledb": oracledb,
"oracledb.connection": oracledb_connection,
}
def _connection_with_cursor(cursor):
cursor_ctx = MagicMock()
cursor_ctx.__enter__.return_value = cursor
cursor_ctx.__exit__.return_value = None
connection = MagicMock()
connection.__enter__.return_value = connection
connection.__exit__.return_value = None
connection.cursor.return_value = cursor_ctx
return connection
@pytest.fixture
def oracle_module(monkeypatch):
    """Import the oracle vector module with fake jieba/oracledb modules installed."""
    for name, module in _build_fake_oracle_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.oracle.oraclevector as module
    # Reload so module-level names rebind to the fakes registered above.
    return importlib.reload(module)
def _config(module, **overrides):
values = {
"user": "system",
"password": "oracle",
"dsn": "oracle:1521/freepdb1",
"is_autonomous": False,
}
values.update(overrides)
return module.OracleVectorConfig.model_validate(values)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("user", "", "config ORACLE_USER is required"),
        ("password", "", "config ORACLE_PASSWORD is required"),
        ("dsn", "", "config ORACLE_DSN is required"),
    ],
)
def test_oracle_config_validation_required_fields(oracle_module, field, value, message):
    """Each mandatory connection field raises a targeted validation error when blank."""
    values = _config(oracle_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        oracle_module.OracleVectorConfig.model_validate(values)
def test_oracle_config_validation_autonomous_requirements(oracle_module):
    """Autonomous mode additionally requires a wallet config_dir."""
    with pytest.raises(ValidationError, match="config_dir is required"):
        oracle_module.OracleVectorConfig.model_validate(
            {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True}
        )
def test_init_and_get_type(oracle_module, monkeypatch):
    """Construction creates a connection pool and derives the embedding table name."""
    pool = MagicMock()
    monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool))
    vector = oracle_module.OracleVector("collection_1", _config(oracle_module))
    assert vector.get_type() == "oracle"
    # Table name is the collection name with an "embedding_" prefix.
    assert vector.table_name == "embedding_collection_1"
    assert vector.pool is pool
def test_numpy_converters_and_type_handlers(oracle_module):
    """numpy arrays round-trip through ``array`` typecodes d/f/b in both directions."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64))
    in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32))
    in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8))
    assert in_float64.typecode == "d"
    assert in_float32.typecode == "f"
    assert in_int8.typecode == "b"
    # The input handler registers a DB_TYPE_VECTOR var with the inbound converter.
    cursor = MagicMock()
    vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2)
    cursor.var.assert_called_with(
        oracle_module.oracledb.DB_TYPE_VECTOR,
        arraysize=2,
        inconverter=vector.numpy_converter_in,
    )
    # The output handler matches on the column's type_code and uses cursor.arraysize.
    metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR)
    cursor.arraysize = 3
    vector.output_type_handler(cursor, metadata)
    cursor.var.assert_called_with(
        metadata.type_code,
        arraysize=3,
        outconverter=vector.numpy_converter_out,
    )
    out_int8 = vector.numpy_converter_out(array.array("b", [1]))
    assert out_int8.dtype == numpy.int8
    out_float32 = vector.numpy_converter_out(array.array("f", [1.0]))
    assert out_float32.dtype == numpy.float32
    out_float64 = vector.numpy_converter_out(array.array("d", [1.0]))
    assert out_float64.dtype == numpy.float64
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch):
    """_get_connection passes wallet settings only when the config is autonomous."""
    connect = MagicMock(return_value="connection")
    monkeypatch.setattr(oracle_module.oracledb, "connect", connect)
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.config = _config(oracle_module)
    assert vector._get_connection() == "connection"
    connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1")
    # Autonomous config adds config_dir and wallet parameters to the connect call.
    vector.config = _config(
        oracle_module,
        is_autonomous=True,
        config_dir="/wallet",
        wallet_location="/wallet",
        wallet_password="pw",
    )
    vector._get_connection()
    assert connect.call_args.kwargs["config_dir"] == "/wallet"
    assert connect.call_args.kwargs["wallet_location"] == "/wallet"
def test_create_delegates_collection_and_insert(oracle_module):
    """create() sizes the collection from the first embedding and delegates inserts."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock(return_value=["seg-1"])
    docs = [Document(page_content="doc", metadata={"doc_id": "seg-1"})]
    result = vector.create(docs, [[0.1, 0.2]])
    assert result == ["seg-1"]
    # Dimension comes from the first embedding's length.
    vector._create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch):
    """add_texts inserts each doc, generating missing ids, and survives a failing insert."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.input_type_handler = MagicMock()
    vector.output_type_handler = MagicMock()
    cursor = MagicMock()
    # The second execute raises to exercise the per-row error handling path.
    cursor.execute.side_effect = [None, RuntimeError("insert failed")]
    connection = _connection_with_cursor(cursor)
    vector._get_connection = MagicMock(return_value=connection)
    monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        # No metadata at all: exercises the metadata=None fallback.
        SimpleNamespace(page_content="c", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # Three inputs yield two ids here — presumably the failed insert is dropped
    # and the metadata-less doc gets the stubbed uuid; confirm in the implementation.
    assert ids == ["doc-a", "generated-uuid"]
    assert cursor.execute.call_count == 2
    assert connection.commit.call_count >= 1
    connection.close.assert_called()
def test_text_exists_and_get_by_ids(oracle_module):
    """text_exists and get_by_ids read rows and release the pooled connection."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.pool = MagicMock()
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
    # The connection must go back to the pool after the read.
    vector.pool.release.assert_called_once()
    # An empty id list short-circuits without touching the database.
    assert vector.get_by_ids([]) == []
def test_delete_methods(oracle_module):
    """delete_by_ids / delete_by_metadata_field / delete issue the expected SQL."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    # An empty id list must not open a connection at all.
    vector.delete_by_ids([])
    vector._get_connection.assert_not_called()
    vector.delete_by_ids(["id-1", "id-2"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql)
    # Metadata deletes filter through Oracle's JSON_VALUE accessor.
    assert any("JSON_VALUE(meta" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_search_by_vector_with_threshold_and_filter(oracle_module):
    """Vector search applies the score threshold and document-id filter SQL."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector.input_type_handler = MagicMock()
    vector.output_type_handler = MagicMock()
    cursor = MagicMock()
    # Rows are (meta, text, distance); the 0.9 score asserted below implies
    # score = 1 - distance — confirm against the implementation.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, "doc-2", 0.8)])
    connection = _connection_with_cursor(cursor)
    vector._get_connection = MagicMock(return_value=connection)
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=0,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    sql = cursor.execute.call_args.args[0]
    # With top_k=0 the generated SQL still fetches 4 rows, per the assertion below.
    assert "fetch first 4 rows only" in sql
    assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql
def _fake_nltk_module(*, missing_data=False):
nltk = types.ModuleType("nltk")
nltk_corpus = types.ModuleType("nltk.corpus")
class _Data:
@staticmethod
def find(_path):
if missing_data:
raise LookupError("missing")
return True
nltk.data = _Data()
nltk.word_tokenize = lambda text: text.split()
nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"])
return nltk, nltk_corpus
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch):
    """Full-text search tokenizes Chinese via jieba posseg and English via nltk."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])])
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    # jieba part-of-speech tags drive the Chinese keyword extraction.
    monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("", "nr"), ("", "nr"), ("", "x")]))
    zh_docs = vector.search_by_full_text("张三", top_k=2)
    assert len(zh_docs) == 1
    zh_params = cursor.execute.call_args.args[1]
    assert zh_params["kk"] == "张三"
    # English path: install fake nltk tokenizer/stopwords before the query.
    nltk, nltk_corpus = _fake_nltk_module(missing_data=False)
    monkeypatch.setitem(sys.modules, "nltk", nltk)
    monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus)
    cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])])
    en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"])
    assert len(en_docs) == 1
    en_sql = cursor.execute.call_args.args[0]
    en_params = cursor.execute.call_args.args[1]
    # With top_k=-1 the generated SQL still fetches 5 rows, per the assertion below.
    assert "fetch first 5 rows only" in en_sql
    assert "doc_id_0" in en_params
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch):
    """Empty queries short-circuit; missing nltk data surfaces a LookupError."""
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector.table_name = "embedding_collection_1"
    vector._get_connection = MagicMock()
    empty_result = vector.search_by_full_text("")
    assert empty_result[0].page_content == ""
    # Missing NLTK data package: the English path must fail loudly, not silently.
    nltk, nltk_corpus = _fake_nltk_module(missing_data=True)
    monkeypatch.setitem(sys.modules, "nltk", nltk)
    monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus)
    with pytest.raises(LookupError, match="required NLTK data package"):
        vector.search_by_full_text("english query")
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
    """_create_collection is a no-op on a redis cache hit; otherwise creates table+index."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock())
    vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
    vector._collection_name = "collection_1"
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor))
    # Cache hit: no SQL should be issued.
    monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(2)
    cursor.execute.assert_not_called()
    # Cache miss: table and index DDL are issued and the cache flag is set.
    monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(2)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql)
    assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql)
    oracle_module.redis_client.set.assert_called_once()
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch):
    """The factory prefers an existing class_prefix; otherwise it generates a name."""
    factory = oracle_module.OracleVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1")
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None)
    monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False)
    with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    # Oracle collection names keep their original casing (no lower-casing here).
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,317 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy.types import UserDefinedType
from core.rag.models.document import Document
def _build_fake_pgvecto_modules():
    """Fabricate ``pgvecto_rs`` modules so the vector store imports without the dependency.

    ``VECTOR`` subclasses SQLAlchemy's ``UserDefinedType`` so it behaves as a
    column type when the ORM model is declared.
    """
    pgvecto_rs = types.ModuleType("pgvecto_rs")
    pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy")

    class VECTOR(UserDefinedType):
        def __init__(self, dim):
            self.dim = dim

    pgvecto_rs_sqlalchemy.VECTOR = VECTOR
    return {
        "pgvecto_rs": pgvecto_rs,
        "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy,
    }
class _FakeSessionContext:
def __init__(self, calls, execute_results=None):
self.calls = calls
self.execute_results = execute_results or []
self.execute = MagicMock(side_effect=self._execute_side_effect)
self.commit = MagicMock()
def _execute_side_effect(self, *args, **kwargs):
self.calls.append((args, kwargs))
if self.execute_results:
return self.execute_results.pop(0)
return MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return None
def _session_factory(calls, execute_results=None):
    """Return a ``Session``-compatible callable producing fake session contexts.

    Each produced context shares the same *calls* list and *execute_results*
    queue, so queued results are consumed in order across successive sessions.
    """

    def _make_session(_client):
        return _FakeSessionContext(calls=calls, execute_results=execute_results)

    return _make_session
@pytest.fixture
def pgvecto_module(monkeypatch):
    """Reload the pgvecto_rs vector and collection modules against the fake package."""
    for name, module in _build_fake_pgvecto_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module
    import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module
    # Reload both so module-level imports rebind to the fakes registered above.
    return importlib.reload(module), importlib.reload(collection_module)
def _config(module, **overrides):
values = {
"host": "localhost",
"port": 5432,
"user": "postgres",
"password": "secret",
"database": "postgres",
}
values.update(overrides)
return module.PgvectoRSConfig.model_validate(values)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config PGVECTO_RS_HOST is required"),
        ("port", 0, "config PGVECTO_RS_PORT is required"),
        ("user", "", "config PGVECTO_RS_USER is required"),
        ("password", "", "config PGVECTO_RS_PASSWORD is required"),
        ("database", "", "config PGVECTO_RS_DATABASE is required"),
    ],
)
def test_pgvecto_config_validation(pgvecto_module, field, value, message):
    """Every connection field raises a targeted validation error when blank/zero."""
    module, _ = pgvecto_module
    values = _config(module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        module.PgvectoRSConfig.model_validate(values)
def test_collection_base_has_expected_annotations(pgvecto_module):
    """The ORM row model declares the columns the store reads and writes."""
    _, collection_module = pgvecto_module
    annotations = collection_module.CollectionORM.__annotations__
    assert {"id", "text", "meta", "vector"} <= set(annotations)
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
    """Construction builds the engine and extension; create() delegates to helpers."""
    module, _ = pgvecto_module
    session_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(session_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "1"})]
    vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == module.VectorType.PGVECTO_RS
    module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres")
    # Construction must ensure the vectors extension exists.
    assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls)
    # Collection dimension is taken from the first embedding's length.
    vector.create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
    """create_collection skips DDL on a redis cache hit; otherwise emits table+index SQL."""
    module, _ = pgvecto_module
    session_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(session_calls))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(module.redis_client, "set", MagicMock())
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    # Cache hit: no CREATE TABLE should be issued.
    monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection(3)
    assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls)
    # Cache miss: both the table and the embedding index get created.
    monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None))
    vector.create_collection(3)
    assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls)
    assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls)
    module.redis_client.set.assert_called()
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
    """add_texts inserts rows; id lookup and the delete helpers issue the expected SQL."""
    module, _ = pgvecto_module
    init_calls = []
    runtime_calls = []
    execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(init_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))

    # Stand-in for sqlalchemy.insert that just records the generated values.
    class _InsertBuilder:
        def __init__(self, table):
            self.table = table

        def values(self, **kwargs):
            return ("insert", kwargs)

    monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table))
    monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"]))
    docs = [
        Document(page_content="a", metadata={"doc_id": "1"}),
        Document(page_content="b", metadata={"doc_id": "2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    # Row ids are freshly generated uuids, not the docs' metadata ids.
    assert ids == ["uuid-1", "uuid-2"]
    assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0])
    # Id lookup by metadata field returns rows, or None when nothing matches.
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
    )
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]),
    )
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
    # delete_by_metadata_field deletes whatever ids the lookup returned.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
    runtime_calls.clear()
    # delete_by_ids first resolves row ids by doc_id, then deletes by row id.
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(
            runtime_calls,
            execute_results=[
                SimpleNamespace(fetchall=lambda: [("row-id-1",)]),
                MagicMock(),
            ],
        ),
    )
    vector.delete_by_ids(["doc-1"])
    assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
    assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
    runtime_calls.clear()
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
    vector.delete()
    assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
    """text_exists, vector-search scoring/filtering, and the no-op full-text path."""
    module, _ = pgvecto_module
    init_calls = []
    monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
    monkeypatch.setattr(module, "Session", _session_factory(init_calls))
    vector = module.PGVectoRS("collection_1", _config(module), dim=3)
    runtime_calls = []
    monkeypatch.setattr(
        module,
        "Session",
        _session_factory(
            runtime_calls,
            execute_results=[
                SimpleNamespace(fetchall=lambda: [("id-1",)]),
                SimpleNamespace(fetchall=lambda: []),
            ],
        ),
    )
    assert vector.text_exists("doc-1") is True
    assert vector.text_exists("doc-1") is False

    # Minimal stand-ins for the SQLAlchemy column/expression objects the query builds on.
    class _DistanceExpr:
        def label(self, _name):
            return self

    class _VectorColumn:
        def op(self, _operator, return_type=None):
            def _call(_query_vector):
                return _DistanceExpr()

            return _call

    class _MetaFilter:
        def in_(self, values):
            return ("in", values)

    class _MetaColumn:
        def __getitem__(self, _item):
            return _MetaFilter()

    class _Stmt:
        def __init__(self):
            self.where_called = False

        def limit(self, _value):
            return self

        def order_by(self, _value):
            return self

        def where(self, _value):
            self.where_called = True
            return self

    stmt = _Stmt()
    monkeypatch.setattr(module, "select", lambda *_args: stmt)
    vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn())
    rows = [
        (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
        (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
    ]
    monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
    # Distance 0.1 maps to the asserted 0.9 score and passes the threshold;
    # distance 0.8 is filtered out.
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # The document_ids_filter must add a where clause to the statement.
    assert stmt.where_called is True
    # Full-text search is not implemented for pgvecto_rs; it returns an empty list.
    assert vector.search_by_full_text("hello") == []
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch):
    """The factory reuses class_prefix or generates a name, lower-cased either way."""
    module, _ = pgvecto_module
    factory = module.PGVectoRSFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432)
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret")
    monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres")
    # embed_query is stubbed — presumably used to size the vector dimension;
    # confirm against the factory implementation.
    embeddings = MagicMock()
    embeddings.embed_query.return_value = [0.1, 0.2, 0.3]
    with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings)
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings)
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    assert dataset_without_index.index_struct is not None

View File

@ -1,16 +1,19 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module
from core.rag.datasource.vdb.pgvector.pgvector import (
    PGVector,
    PGVectorConfig,
)
from core.rag.models.document import Document


class TestPGVector:
    def setup_method(self, method):
        self.config = PGVectorConfig(
            host="localhost",
            port=5432,
@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override):
PGVectorConfig(**config) PGVectorConfig(**config)
def test_create_delegates_collection_creation_and_insert():
    vector = PGVector.__new__(PGVector)
vector._create_collection = MagicMock()
vector.add_texts = MagicMock(return_value=["doc-a"])
docs = [Document(page_content="hello", metadata={"doc_id": "doc-a"})]
result = vector.create(docs, [[0.1, 0.2]])
assert result == ["doc-a"]
vector._create_collection.assert_called_once_with(2)
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch):
    """add_texts batches rows through psycopg2 execute_values and returns doc ids."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid")
    execute_values = MagicMock()
    monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values)
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        SimpleNamespace(page_content="c", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # Three inputs yield two ids — presumably the doc without a usable id is
    # skipped and the metadata-less doc falls back to the stubbed uuid;
    # confirm against the implementation.
    assert ids == ["doc-a", "generated-uuid"]
    # All rows must go through a single batched execute_values call.
    execute_values.assert_called_once()
def test_text_get_and_delete_methods():
    """text_exists/get_by_ids read via the cursor; metadata delete and drop emit SQL."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    # Metadata deletes go through a parameterized JSON accessor, not string building.
    assert any("meta->>%s = %s" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch):
    """delete_by_ids tolerates empty input and a missing table, but re-raises the rest."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    # An empty id list must not execute any SQL.
    vector.delete_by_ids([])
    cursor.execute.assert_not_called()

    class _UndefinedTableError(Exception):
        pass

    monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError)
    # UndefinedTable is swallowed (the table was never created) ...
    cursor.execute.side_effect = _UndefinedTableError("missing")
    vector.delete_by_ids(["doc-1"])
    # ... while any other error propagates to the caller.
    cursor.execute.side_effect = RuntimeError("boom")
    with pytest.raises(RuntimeError, match="boom"):
        vector.delete_by_ids(["doc-1"])
def test_search_by_vector_supports_filter_and_threshold():
    """search_by_vector validates top_k, applies the threshold, and inlines the id filter."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows are (meta, text, distance); the asserted 0.9 score implies
    # score = 1 - distance — confirm against the implementation.
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1], top_k=0)
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    sql = cursor.execute.call_args.args[0]
    assert "meta->>'document_id' in ('d-1')" in sql
def test_search_by_full_text_branches_for_bigm_and_standard():
    """Full-text search picks tsvector SQL or pg_bigm similarity based on the flag."""
    vector = PGVector.__new__(PGVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)])

    @contextmanager
    def _cursor_ctx():
        yield cursor

    vector._get_cursor = _cursor_ctx
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("hello", top_k=0)
    # Standard path: plain tsvector full-text query.
    vector.pg_bigm = False
    docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.7)
    standard_sql = cursor.execute.call_args.args[0]
    assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql
    cursor.execute.reset_mock()
    cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)])
    # pg_bigm path: first lowers the similarity limit, then queries with bigm_similarity.
    vector.pg_bigm = True
    vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"])
    assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0]
    assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0]
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch):
    """The factory reuses an existing class_prefix and generates a name otherwise."""
    factory = pgvector_module.PGVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full PGVECTOR_* config surface the factory reads.
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres")
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5)
    monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False)
    with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A fresh dataset must get its index struct persisted.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,269 @@
import importlib
import sys
import types
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_psycopg2_modules():
psycopg2 = types.ModuleType("psycopg2")
psycopg2.__path__ = []
psycopg2_extras = types.ModuleType("psycopg2.extras")
psycopg2_pool = types.ModuleType("psycopg2.pool")
class SimpleConnectionPool:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.getconn = MagicMock()
self.putconn = MagicMock()
psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool
psycopg2_extras.execute_values = MagicMock()
psycopg2.pool = psycopg2_pool
psycopg2.extras = psycopg2_extras
return {
"psycopg2": psycopg2,
"psycopg2.pool": psycopg2_pool,
"psycopg2.extras": psycopg2_extras,
}
@pytest.fixture
def vastbase_module(monkeypatch):
    """Import the vastbase vector module against stub psycopg2 modules.

    Reloading ensures the module binds to the stubs even if it was already
    imported against the real driver by another test.
    """
    for name, module in _build_fake_psycopg2_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module
    return importlib.reload(module)
def _config(module):
return module.VastbaseVectorConfig(
host="localhost",
port=5432,
user="dify",
password="secret",
database="dify",
min_connection=1,
max_connection=5,
)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config VASTBASE_HOST is required"),
        ("port", 0, "config VASTBASE_PORT is required"),
        ("user", "", "config VASTBASE_USER is required"),
        ("password", "", "config VASTBASE_PASSWORD is required"),
        ("database", "", "config VASTBASE_DATABASE is required"),
        ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"),
        ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"),
    ],
)
def test_vastbase_config_validation(vastbase_module, field, value, message):
    """Each required config field rejects its empty/zero value with its own message."""
    values = _config(vastbase_module).model_dump()
    # Invalidate exactly one field per parametrized case.
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        vastbase_module.VastbaseVectorConfig.model_validate(values)
def test_vastbase_config_rejects_invalid_connection_window(vastbase_module):
    """min_connection greater than max_connection must fail validation."""
    with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"):
        vastbase_module.VastbaseVectorConfig.model_validate(
            {
                "host": "localhost",
                "port": 5432,
                "user": "dify",
                "password": "secret",
                "database": "dify",
                "min_connection": 6,
                "max_connection": 5,
            }
        )
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
    """__init__ builds a pool; _get_cursor yields a cursor then commits and returns the conn."""
    pool = MagicMock()
    monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
    conn = MagicMock()
    cur = MagicMock()
    pool.getconn.return_value = conn
    conn.cursor.return_value = cur
    vector = vastbase_module.VastbaseVector("collection_1", _config(vastbase_module))
    assert vector.get_type() == "vastbase"
    assert vector.table_name == "embedding_collection_1"
    with vector._get_cursor() as got_cur:
        assert got_cur is cur
    # On context exit: cursor closed, transaction committed, conn returned to pool.
    cur.close.assert_called_once()
    conn.commit.assert_called_once()
    pool.putconn.assert_called_once_with(conn)
def test_create_and_add_texts(vastbase_module, monkeypatch):
    """add_texts inserts docs (generating ids when doc_id is absent) and create delegates."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    vector._create_collection = MagicMock()
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # Pin uuid4 so the generated id is predictable.
    monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid")
    docs = [
        Document(page_content="a", metadata={"doc_id": "doc-a"}),
        Document(page_content="b", metadata={"document_id": "doc-b"}),
        SimpleNamespace(page_content="c", metadata=None),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
    # Doc without doc_id gets the pinned uuid; the metadata-less doc yields no id.
    assert ids == ["doc-a", "generated-uuid"]
    vastbase_module.psycopg2.extras.execute_values.assert_called_once()
    vector.add_texts = MagicMock(return_value=["doc-a"])
    result = vector.create(docs, [[0.1], [0.2], [0.3]])
    # create() sizes the collection from the embedding dimension (1 here).
    vector._create_collection.assert_called_once_with(1)
    assert result == ["doc-a"]
def test_text_get_delete_and_metadata_methods(vastbase_module):
    """Smoke-test text_exists, get_by_ids, the delete variants, and their SQL shapes."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    cursor.fetchone.return_value = ("id-1",)
    cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")])
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    assert vector.text_exists("id-1") is True
    docs = vector.get_by_ids(["id-1", "id-2"])
    assert len(docs) == 2
    assert docs[0].page_content == "text-1"
    # Empty id list must be a no-op; the rest exercise each delete path.
    vector.delete_by_ids([])
    vector.delete_by_ids(["id-1"])
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete()
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql)
    assert any("meta->>%s = %s" in sql for sql in executed_sql)
    assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_search_by_vector_and_full_text(vastbase_module):
    """Vector search enforces top_k and threshold; full-text search maps rows to docs."""
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    # Rows are (metadata, text, distance); 0.1 -> score 0.9, 0.8 -> score 0.2.
    cursor.__iter__.return_value = iter(
        [
            ({"doc_id": "1"}, "text-1", 0.1),
            ({"doc_id": "2"}, "text-2", 0.8),
        ]
    )
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_vector([0.1, 0.2], top_k=0)
    docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5)
    # Only the high-score row survives the threshold.
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    with pytest.raises(ValueError, match="top_k must be a positive integer"):
        vector.search_by_full_text("hello", top_k=0)
    cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)])
    full_docs = vector.search_by_full_text("hello world", top_k=2)
    assert len(full_docs) == 1
    assert full_docs[0].page_content == "full-text"
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch):
    """_create_collection is a no-op when the redis cache flag is set, and only
    builds the cosine index for small dimensions (3 yes, 17000 no)."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock())
    vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
    vector._collection_name = "collection_1"
    vector.table_name = "embedding_collection_1"
    cursor = MagicMock()
    @contextmanager
    def _cursor_ctx():
        yield cursor
    vector._get_cursor = _cursor_ctx
    # Cache hit: nothing should be executed.
    monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    cursor.execute.assert_not_called()
    # Cache miss + very large dimension: table created, index skipped.
    monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None))
    vector._create_collection(17000)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql)
    assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql)
    cursor.execute.reset_mock()
    # Cache miss + small dimension: index is created as well.
    vector._create_collection(3)
    executed_sql = [call.args[0] for call in cursor.execute.call_args_list]
    assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql)
    vastbase_module.redis_client.set.assert_called()
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch):
    """The factory reuses an existing class_prefix and generates a name otherwise."""
    factory = vastbase_module.VastbaseVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full VASTBASE_* config surface the factory reads.
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432)
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify")
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1)
    monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MAX_CONNECTION", 5)
    with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A fresh dataset must get its index struct persisted.
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,328 @@
import importlib
import os
import sys
import types
from collections import UserDict
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_qdrant_modules():
    """Create stub qdrant_client modules covering the API surface the vector code uses.

    Returns a mapping of dotted module name -> stub module to install into
    ``sys.modules`` before importing the qdrant vector module.
    """
    qdrant_client = types.ModuleType("qdrant_client")
    qdrant_http = types.ModuleType("qdrant_client.http")
    qdrant_http_models = types.ModuleType("qdrant_client.http.models")
    qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions")
    qdrant_local_pkg = types.ModuleType("qdrant_client.local")
    qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local")
    # Mirrors UnexpectedResponse: carries the HTTP status code the code branches on.
    class UnexpectedResponseError(Exception):
        def __init__(self, status_code):
            super().__init__(f"status={status_code}")
            self.status_code = status_code
    # The model stubs below just record their constructor arguments.
    class FilterSelector:
        def __init__(self, filter):
            self.filter = filter
    class HnswConfigDiff:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
    class TextIndexParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
    class VectorParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
    class PointStruct:
        def __init__(self, **kwargs):
            self.id = kwargs["id"]
            self.vector = kwargs["vector"]
            self.payload = kwargs["payload"]
    class Filter:
        def __init__(self, must=None):
            self.must = must or []
    class FieldCondition:
        def __init__(self, key, match):
            self.key = key
            self.match = match
    class MatchValue:
        def __init__(self, value):
            self.value = value
    class MatchAny:
        def __init__(self, any):
            self.any = any
    class MatchText:
        def __init__(self, text):
            self.text = text
    # Distance["COSINE"] (etc.) simply echoes the key back as the value.
    class _Distance(UserDict):
        def __getitem__(self, key):
            return key
    # Client whose every method is a MagicMock with a benign default return.
    class QdrantClient:
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[]))
            self.create_collection = MagicMock()
            self.create_payload_index = MagicMock()
            self.upsert = MagicMock()
            self.delete = MagicMock()
            self.delete_collection = MagicMock()
            self.retrieve = MagicMock(return_value=[])
            self.search = MagicMock(return_value=[])
            self.scroll = MagicMock(return_value=([], None))
    # Local variant adds the _load hook exercised by _reload_if_needed.
    class QdrantLocal(QdrantClient):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._load = MagicMock()
    # Wire the stub classes onto the stub modules.
    qdrant_client.QdrantClient = QdrantClient
    qdrant_http_models.FilterSelector = FilterSelector
    qdrant_http_models.HnswConfigDiff = HnswConfigDiff
    qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD")
    qdrant_http_models.TextIndexParams = TextIndexParams
    qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT")
    qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL")
    qdrant_http_models.VectorParams = VectorParams
    qdrant_http_models.Distance = _Distance()
    qdrant_http_models.PointStruct = PointStruct
    qdrant_http_models.Filter = Filter
    qdrant_http_models.FieldCondition = FieldCondition
    qdrant_http_models.MatchValue = MatchValue
    qdrant_http_models.MatchAny = MatchAny
    qdrant_http_models.MatchText = MatchText
    qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError
    qdrant_http.models = qdrant_http_models
    qdrant_local_mod.QdrantLocal = QdrantLocal
    qdrant_local_pkg.qdrant_local = qdrant_local_mod
    return {
        "qdrant_client": qdrant_client,
        "qdrant_client.http": qdrant_http,
        "qdrant_client.http.models": qdrant_http_models,
        "qdrant_client.http.exceptions": qdrant_http_exceptions,
        "qdrant_client.local": qdrant_local_pkg,
        "qdrant_client.local.qdrant_local": qdrant_local_mod,
    }
@pytest.fixture
def qdrant_module(monkeypatch):
    """Import the qdrant vector module against stub qdrant_client modules.

    Reloading ensures the module binds to the stubs rather than a cached import.
    """
    for name, module in _build_fake_qdrant_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.qdrant.qdrant_vector as module
    return importlib.reload(module)
def _config(module, **overrides):
values = {
"endpoint": "http://localhost:6333",
"api_key": "api-key",
"timeout": 20,
"root_path": "/tmp",
"grpc_port": 6334,
"prefer_grpc": False,
"replication_factor": 1,
"write_consistency_factor": 1,
}
values.update(overrides)
return module.QdrantConfig.model_validate(values)
def test_qdrant_config_to_params(qdrant_module):
    """to_qdrant_params handles URL endpoints, path: endpoints, and a missing root path."""
    url_params = _config(qdrant_module).to_qdrant_params().model_dump()
    assert url_params["url"] == "http://localhost:6333"
    assert url_params["verify"] is False
    # "path:" endpoints resolve relative to root_path.
    path_config = _config(qdrant_module, endpoint="path:storage")
    assert path_config.to_qdrant_params().path == os.path.join("/tmp", "storage")
    with pytest.raises(ValueError, match="Root path is not set"):
        _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params()
def test_init_and_basic_behaviour(qdrant_module):
    """Constructor wires type/index-struct metadata; create() delegates to the helpers."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    assert vector.get_type() == qdrant_module.VectorType.QDRANT
    assert vector.to_index_struct()["vector_store"]["class_prefix"] == "collection_1"
    docs = [Document(page_content="a", metadata={"doc_id": "a"})]
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    vector.create(docs, [[0.1]])
    # Collection sized from the embedding dimension (1 here), then texts added.
    vector.create_collection.assert_called_once_with("collection_1", 1)
    vector.add_texts.assert_called_once()
def test_create_collection_and_add_texts(qdrant_module, monkeypatch):
    """create_collection respects the redis cache flag and builds payload indexes;
    add_texts upserts points and _build_payloads rejects None texts."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock())
    # Cache hit: no collection is created.
    monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1))
    vector.create_collection("collection_1", 3)
    vector._client.create_collection.assert_not_called()
    # Cache miss: collection plus four payload indexes are created.
    monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None))
    vector._client.get_collections.return_value = SimpleNamespace(collections=[])
    vector.create_collection("collection_1", 3)
    vector._client.create_collection.assert_called_once()
    assert vector._client.create_payload_index.call_count == 4
    qdrant_module.redis_client.set.assert_called_once()
    # add_texts and generated batches
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    assert vector._client.upsert.call_count == 1
    payloads = qdrant_module.QdrantVector._build_payloads(
        ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id"
    )
    assert payloads[0]["group_id"] == "g1"
    with pytest.raises(ValueError, match="At least one of the texts is None"):
        qdrant_module.QdrantVector._build_payloads(
            [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id"
        )
def test_delete_and_exists_paths(qdrant_module):
    """Delete operations tolerate 404 UnexpectedResponse, re-raise other statuses;
    text_exists checks collection membership before retrieving the point."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse
    # delete_by_metadata_field: 404 swallowed, 500 propagated.
    vector._client.delete.side_effect = unexpected(404)
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.side_effect = None
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete_by_metadata_field("document_id", "doc-1")
    vector._client.delete.side_effect = None
    # delete(): same 404-vs-500 contract.
    vector._client.delete.side_effect = unexpected(404)
    vector.delete()
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete()
    vector._client.delete.side_effect = None
    # delete_by_ids(): same 404-vs-500 contract.
    vector._client.delete.side_effect = unexpected(404)
    vector.delete_by_ids(["doc-1"])
    vector._client.delete.side_effect = unexpected(500)
    with pytest.raises(unexpected):
        vector.delete_by_ids(["doc-1"])
    vector._client.delete.side_effect = None
    # Missing collection short-circuits to False; a retrieved point means True.
    vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")])
    assert vector.text_exists("id-1") is False
    vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")])
    vector._client.retrieve.return_value = [{"id": "id-1"}]
    assert vector.text_exists("id-1") is True
def test_search_and_helper_methods(qdrant_module):
    """Covers vector search filtering, full-text scroll/dedup, local reload, and
    the scored-point -> Document conversion helper."""
    vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
    # A threshold of 1.0 filters everything from the default empty search.
    assert vector.search_by_vector([0.1], score_threshold=1.0) == []
    vector._client.search.return_value = [
        SimpleNamespace(payload=None, score=0.9, vector=[0.1]),
        SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]),
    ]
    # Payload-less hits are dropped even with a high score.
    docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    # full text search: keyword split, dedup and top_k limit
    scroll_results = [
        (
            [
                SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]),
                SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]),
            ],
            None,
        ),
        (
            [
                SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]),
            ],
            None,
        ),
    ]
    vector._client.scroll.side_effect = scroll_results
    docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"])
    # p2 appears in both scroll pages but is deduplicated.
    assert len(docs) == 2
    # Whitespace-only queries return no results.
    assert vector.search_by_full_text(" ", top_k=2) == []
    # _reload_if_needed only loads for the local client variant.
    local_client = qdrant_module.QdrantLocal()
    vector._client = local_client
    vector._reload_if_needed()
    local_client._load.assert_called_once()
    doc = vector._document_from_scored_point(
        SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]),
        "page_content",
        "metadata",
    )
    assert doc.page_content == "doc"
def test_qdrant_factory_paths(qdrant_module, monkeypatch):
    """Factory generates a collection name, honours collection bindings, and raises
    when a referenced binding does not exist."""
    factory = qdrant_module.QdrantVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1",
        tenant_id="tenant-1",
        collection_binding_id=None,
        index_struct_dict=None,
        index_struct=None,
    )
    monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root")))
    # Provide the full QDRANT_* config surface the factory reads.
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333")
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key")
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False)
    monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_REPLICATION_FACTOR", 1)
    with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls:
        result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert result == "vector"
    assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION"
    assert dataset.index_struct is not None
    # collection binding lookup path
    dataset.collection_binding_id = "binding-1"
    dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}}
    monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    qdrant_module.db.session.scalars = MagicMock(
        return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION"))
    )
    with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls:
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())
    assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION"
    # A dangling binding id must raise.
    qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None))
    with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"):
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())

View File

@ -0,0 +1,303 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from sqlalchemy.types import UserDefinedType
from core.rag.models.document import Document
def _build_fake_relyt_modules():
    """Create stub pgvecto_rs modules exposing a minimal VECTOR column type."""
    root = types.ModuleType("pgvecto_rs")
    sqlalchemy_mod = types.ModuleType("pgvecto_rs.sqlalchemy")

    class VECTOR(UserDefinedType):
        """Stand-in vector column type that only remembers its dimension."""

        def __init__(self, dim):
            self.dim = dim

    sqlalchemy_mod.VECTOR = VECTOR
    return {
        "pgvecto_rs": root,
        "pgvecto_rs.sqlalchemy": sqlalchemy_mod,
    }
class _FakeSession:
def __init__(self, execute_result=None):
self.execute_result = execute_result or MagicMock(fetchall=lambda: [])
self.execute = MagicMock(return_value=self.execute_result)
self.commit = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return None
@pytest.fixture
def relyt_module(monkeypatch):
    """Import the relyt vector module against stub pgvecto_rs modules.

    Reloading ensures the module binds to the stubs rather than a cached import.
    """
    for name, module in _build_fake_relyt_modules().items():
        monkeypatch.setitem(sys.modules, name, module)
    import core.rag.datasource.vdb.relyt.relyt_vector as module
    return importlib.reload(module)
def _config(module, **overrides):
values = {
"host": "localhost",
"port": 5432,
"user": "postgres",
"password": "secret",
"database": "relyt",
}
values.update(overrides)
return module.RelytConfig.model_validate(values)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config RELYT_HOST is required"),
        ("port", 0, "config RELYT_PORT is required"),
        ("user", "", "config RELYT_USER is required"),
        ("password", "", "config RELYT_PASSWORD is required"),
        ("database", "", "config RELYT_DATABASE is required"),
    ],
)
def test_relyt_config_validation(relyt_module, field, value, message):
    """Each required config field rejects its empty/zero value with its own message."""
    values = _config(relyt_module).model_dump()
    # Invalidate exactly one field per parametrized case.
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        relyt_module.RelytConfig.model_validate(values)
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
    """Constructor derives the SQLAlchemy URL; create() sizes and delegates."""
    engine = MagicMock()
    monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine))
    vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1")
    vector.create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})]
    vector.create(docs, [[0.1, 0.2]])
    assert vector.get_type() == relyt_module.VectorType.RELYT
    assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt"
    # Dimension is taken from the first embedding (length 2 here).
    assert vector.embedding_dimension == 2
    vector.create_collection.assert_called_once_with(2)
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
    """create_collection is a no-op on redis cache hit and runs DDL on a miss."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock())
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    # Cache hit: nothing executed.
    monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
    session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    vector.create_collection(3)
    session.execute.assert_not_called()
    # Cache miss: drop/create table plus index, then mark the cache.
    monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
    session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    vector.create_collection(3)
    executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
    assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
    assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql)
    assert any("CREATE INDEX" in sql for sql in executed_sql)
    relyt_module.redis_client.set.assert_called_once()
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch):
    """add_texts inserts rows carrying group_id; get_ids_by_metadata_field maps
    rows to ids (or None when nothing matches)."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector._group_id = "group-1"
    vector.client = MagicMock()
    # Mock engine.connect() and the nested conn.begin() transaction context.
    begin_ctx = MagicMock()
    begin_ctx.__enter__.return_value = None
    begin_ctx.__exit__.return_value = None
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.begin.return_value = begin_ctx
    vector.client.connect.return_value = conn
    # Pin uuid1 so the returned ids are deterministic.
    monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"]))
    docs = [
        Document(page_content="a", metadata={"doc_id": "d-1"}),
        Document(page_content="b", metadata={"doc_id": "d-2"}),
    ]
    ids = vector.add_texts(docs, [[0.1], [0.2]])
    assert ids == ["id-1", "id-2"]
    assert conn.execute.call_count >= 1
    first_insert_values = conn.execute.call_args.args[0].compile().params
    assert "group_id" in str(first_insert_values)
    # Matching rows yield their ids; an empty result yields None.
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)]))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"]
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: []))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None
# 1. delete_by_uuids: success and connect error
def test_delete_by_uuids_success_and_connect_error(relyt_module):
    """delete_by_uuids rejects None, returns True on success and False on a
    connection failure (the error is swallowed, not raised)."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    with pytest.raises(ValueError, match="No ids provided"):
        vector.delete_by_uuids(None)
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    begin_ctx = MagicMock()
    begin_ctx.__enter__.return_value = None
    begin_ctx.__exit__.return_value = None
    conn.begin.return_value = begin_ctx
    vector.client.connect.return_value = conn
    assert vector.delete_by_uuids(["id-1"]) is True
    # A failing connect is reported via the boolean return, not an exception.
    vector.client.connect.side_effect = RuntimeError("boom")
    assert vector.delete_by_uuids(["id-1"]) is False
# 2. delete_by_metadata_field calls delete_by_uuids
def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module):
    """delete_by_metadata_field resolves matching ids then delegates to delete_by_uuids."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    vector.delete_by_uuids = MagicMock(return_value=True)
    vector.delete_by_metadata_field("document_id", "doc-1")
    vector.delete_by_uuids.assert_called_once_with(["id-1"])
# 3. delete_by_ids translates to uuids
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
    """delete_by_ids looks up the row uuids for the doc ids before deleting."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("uuid-1",), ("uuid-2",)]))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    vector.delete_by_uuids = MagicMock(return_value=True)
    vector.delete_by_ids(["doc-1", "doc-2"])
    vector.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"])
# 4. text_exists True
def test_text_exists_true(relyt_module, monkeypatch):
    """text_exists returns True when the backing query yields a matching row."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector.client = MagicMock()
    relyt_vector._collection_name = "collection_1"
    relyt_vector.embedding_dimension = 3
    fake_session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-1",)]))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: fake_session)
    assert relyt_vector.text_exists("doc-1") is True
# 5. text_exists False
def test_text_exists_false(relyt_module, monkeypatch):
    """text_exists returns False when the backing query yields no rows."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    session = _FakeSession(execute_result=MagicMock(fetchall=lambda: []))
    monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
    assert vector.text_exists("doc-1") is False
# 6. similarity_search_with_score_by_vector returns Documents and scores
def test_similarity_search_with_score_by_vector(relyt_module):
    """Result rows are mapped to (Document, score) pairs in query order."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    result_rows = [
        SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1),
        SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8),
    ]
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execute.return_value.fetchall.return_value = result_rows
    vector.client.connect.return_value = conn
    similarities = vector.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]})
    assert len(similarities) == 2
    assert similarities[0][0].page_content == "doc-a"
# 7. search_by_vector filters by score and ids
def test_search_by_vector_filters_by_score_and_ids(relyt_module):
    """search_by_vector drops low-score hits; full-text search is unsupported (empty)."""
    vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    vector._collection_name = "collection_1"
    vector.client = MagicMock()
    vector.embedding_dimension = 3
    vector.similarity_search_with_score_by_vector = MagicMock(
        return_value=[
            (Document(page_content="a", metadata={"doc_id": "1"}), 0.1),
            (Document(page_content="b", metadata={}), 0.9),
        ]
    )
    docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
    # Only one of the two mocked hits survives the 0.5 threshold.
    assert len(docs) == 1
    assert vector.search_by_full_text("query") == []
# 8. delete commits session
def test_delete_commits_session(relyt_module, monkeypatch):
    """delete() must commit the SQLAlchemy session exactly once."""
    relyt_vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
    relyt_vector._collection_name = "collection_1"
    relyt_vector.client = MagicMock()
    relyt_vector.embedding_dimension = 3
    fake_session = _FakeSession()
    monkeypatch.setattr(relyt_module, "Session", lambda _client: fake_session)
    relyt_vector.delete()
    fake_session.commit.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
    """The factory reuses a stored collection name or generates one from the dataset id."""
    factory = relyt_module.RelytVectorFactory()
    indexed_dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    fresh_dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full connection config the factory reads from dify_config.
    for setting, value in {
        "RELYT_HOST": "localhost",
        "RELYT_PORT": 5432,
        "RELYT_USER": "postgres",
        "RELYT_PASSWORD": "secret",
        "RELYT_DATABASE": "relyt",
    }.items():
        monkeypatch.setattr(relyt_module.dify_config, setting, value)
    with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls:
        assert factory.init_vector(indexed_dataset, attributes=[], embeddings=MagicMock()) == "vector"
        assert factory.init_vector(fresh_dataset, attributes=[], embeddings=MagicMock()) == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A dataset without an index struct gets one written back.
    assert fresh_dataset.index_struct is not None

View File

@ -0,0 +1,316 @@
import importlib
import json
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_tablestore_module():
    """Build an in-memory stand-in for the `tablestore` SDK module.

    Only the attribute surface imported by tablestore_vector.py is mirrored.
    Most constructors return tagged tuples or SimpleNamespace objects so tests
    can inspect how the SDK was invoked without installing the real dependency.
    """
    tablestore = types.ModuleType("tablestore")

    class _BatchGetRowRequest:
        # Accumulates TableInBatchGetRowItem entries added via add().
        def __init__(self):
            self.items = []

        def add(self, item):
            self.items.append(item)

    class _TableInBatchGetRowItem:
        # Mirrors the SDK constructor; the last two positional args are unused here.
        def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver):
            self.table_name = table_name
            self.rows_to_get = rows_to_get
            self.columns_to_get = columns_to_get

    class _Row:
        def __init__(self, primary_key, attribute_columns=None):
            self.primary_key = primary_key
            self.attribute_columns = attribute_columns or []

    class _Client:
        # Every API method is a MagicMock; the defaults model an empty instance
        # (no tables, no indexes, get_row finds nothing).
        def __init__(self, *_args):
            self.list_table = MagicMock(return_value=[])
            self.create_table = MagicMock()
            self.list_search_index = MagicMock(return_value=[])
            self.create_search_index = MagicMock()
            self.delete_search_index = MagicMock()
            self.delete_table = MagicMock()
            self.put_row = MagicMock()
            self.delete_row = MagicMock()
            self.get_row = MagicMock(return_value=(None, None, None))
            self.batch_get_row = MagicMock()
            self.search = MagicMock()

    tablestore.OTSClient = _Client
    tablestore.BatchGetRowRequest = _BatchGetRowRequest
    tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem
    tablestore.Row = _Row
    # The remaining SDK symbols only need to be constructible; tagged tuples and
    # SimpleNamespace values keep the call arguments inspectable in assertions.
    tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema)
    tablestore.TableOptions = lambda: ("table_options",)
    tablestore.CapacityUnit = lambda read, write: ("capacity", read, write)
    tablestore.ReservedThroughput = lambda cap: ("reserved", cap)
    tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs)
    tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs)
    tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas)
    tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs)
    tablestore.TermQuery = lambda key, value: ("term_query", key, value)
    tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs)
    tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs)
    tablestore.TermsQuery = lambda key, values: ("terms_query", key, values)
    tablestore.Sort = lambda **kwargs: ("sort", kwargs)
    tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs)
    tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs)
    tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs)
    # Enum-like namespaces referenced by the module under test.
    tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD")
    tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD")
    tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32")
    tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE")
    tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX")
    tablestore.SortOrder = SimpleNamespace(DESC="DESC")
    return tablestore
@pytest.fixture
def tablestore_module(monkeypatch):
    """Reload the tablestore vector module against the fake `tablestore` SDK."""
    monkeypatch.setitem(sys.modules, "tablestore", _build_fake_tablestore_module())
    import core.rag.datasource.vdb.tablestore.tablestore_vector as module

    # Reload so the module binds to the freshly injected fake SDK.
    return importlib.reload(module)
def _config(module, **overrides):
    """Build a valid TableStoreConfig, letting tests override individual fields."""
    base = {
        "access_key_id": "ak",
        "access_key_secret": "sk",
        "instance_name": "instance",
        "endpoint": "endpoint",
        "normalize_full_text_bm25_score": False,
    }
    base.update(overrides)
    return module.TableStoreConfig.model_validate(base)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("access_key_id", "", "config ACCESS_KEY_ID is required"),
        ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"),
        ("instance_name", "", "config INSTANCE_NAME is required"),
        ("endpoint", "", "config ENDPOINT is required"),
    ],
)
def test_tablestore_config_validation(tablestore_module, field, value, message):
    """Blanking any required field raises a ValidationError naming that field."""
    payload = _config(tablestore_module).model_dump()
    payload[field] = value
    with pytest.raises(ValidationError, match=message):
        tablestore_module.TableStoreConfig.model_validate(payload)
def test_init_and_basic_delegation(tablestore_module):
    """Constructor wires names; create()/create_collection() delegate to helpers."""
    ts_vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    assert ts_vector.get_type() == tablestore_module.VectorType.TABLESTORE
    assert ts_vector._table_name == "collection_1"
    assert ts_vector._index_name == "collection_1_idx"
    ts_vector._create_collection = MagicMock()
    ts_vector.add_texts = MagicMock()
    documents = [Document(page_content="hello", metadata={"doc_id": "d-1"})]
    ts_vector.create(documents, [[0.1, 0.2]])
    # The collection is sized from the embedding dimension (here 2).
    ts_vector._create_collection.assert_called_once_with(2)
    ts_vector.add_texts.assert_called_once_with(documents=documents, embeddings=[[0.1, 0.2]])
    ts_vector.create_collection([[0.1, 0.2]])
    assert ts_vector._create_collection.call_count == 2
def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module):
    """Cover the read/exists/delete wrappers that delegate to private helpers."""
    ts_vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))

    # get_by_ids: only batch rows flagged is_ok become Documents.
    ok_item = SimpleNamespace(
        is_ok=True,
        row=SimpleNamespace(
            attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)]
        ),
    )
    failed_item = SimpleNamespace(is_ok=False, row=None)
    ts_vector._tablestore_client.batch_get_row.return_value = SimpleNamespace(
        get_result_by_table=lambda _table: [ok_item, failed_item]
    )
    documents = ts_vector.get_by_ids(["id-1"])
    assert len(documents) == 1
    assert documents[0].page_content == "text-1"

    # text_exists mirrors whether get_row returned a row object.
    ts_vector._tablestore_client.get_row.return_value = (None, object(), None)
    assert ts_vector.text_exists("id-1") is True
    ts_vector._tablestore_client.get_row.return_value = (None, None, None)
    assert ts_vector.text_exists("id-1") is False

    # delete_by_ids: no-op on empty input, one _delete_row call per id otherwise.
    ts_vector._delete_row = MagicMock()
    ts_vector.delete_by_ids([])
    ts_vector._delete_row.assert_not_called()
    ts_vector.delete_by_ids(["id-1", "id-2"])
    assert ts_vector._delete_row.call_count == 2

    ts_vector._search_by_metadata = MagicMock(return_value=["id-a"])
    assert ts_vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"]
    ts_vector.delete_by_ids = MagicMock()
    ts_vector.delete_by_metadata_field("document_id", "doc-1")
    ts_vector.delete_by_ids.assert_called_once_with(["id-a"])

    # Public search entry points forward to the private implementations.
    ts_vector._search_by_vector = MagicMock(return_value=["vec-doc"])
    ts_vector._search_by_full_text = MagicMock(return_value=["fts-doc"])
    assert ts_vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"]
    assert ts_vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == [
        "fts-doc"
    ]

    ts_vector._delete_table_if_exist = MagicMock()
    ts_vector.delete()
    ts_vector._delete_table_if_exist.assert_called_once()
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch):
    """Redis flag short-circuits creation; otherwise table/index are created once."""
    ts_vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock())

    # Cache hit: nothing is created.
    monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1))
    ts_vector._create_table_if_not_exist = MagicMock()
    ts_vector._create_search_index_if_not_exist = MagicMock()
    ts_vector._create_collection(3)
    ts_vector._create_table_if_not_exist.assert_not_called()

    # Cache miss: table and index are created and the cache flag is written.
    monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None))
    ts_vector._create_collection(3)
    ts_vector._create_table_if_not_exist.assert_called_once()
    ts_vector._create_search_index_if_not_exist.assert_called_once_with(3)
    tablestore_module.redis_client.set.assert_called_once()

    # Exercise the real table/index helpers on a fresh instance.
    ts_vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module))
    ts_vector._tablestore_client.list_table.return_value = ["collection_2"]
    assert ts_vector._create_table_if_not_exist() is None
    ts_vector._tablestore_client.list_table.return_value = []
    ts_vector._create_table_if_not_exist()
    ts_vector._tablestore_client.create_table.assert_called_once()

    ts_vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")]
    assert ts_vector._create_search_index_if_not_exist(3) is None
    ts_vector._tablestore_client.list_search_index.return_value = []
    ts_vector._create_search_index_if_not_exist(3)
    ts_vector._tablestore_client.create_search_index.assert_called_once()

    # Table deletion drops every search index first, then the table itself.
    ts_vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")]
    ts_vector._delete_table_if_exist()
    assert ts_vector._tablestore_client.delete_search_index.call_count == 2
    ts_vector._tablestore_client.delete_table.assert_called_once_with("collection_2")
    ts_vector._delete_search_index()
    ts_vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx")
def test_write_row_and_search_helpers(tablestore_module):
    """Cover _write_row, _delete_row and the private metadata/vector/full-text searches."""
    vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
    vector._write_row(
        "id-1",
        {
            "page_content": "hello",
            "vector": [0.1, 0.2],
            "metadata": {"doc_id": "d-1", "document_id": "doc-1"},
        },
    )
    put_row_call = vector._tablestore_client.put_row.call_args
    assert put_row_call.args[0] == "collection_1"
    attrs = put_row_call.args[1].attribute_columns
    # The write adds a "metadata_tags" attribute column — presumably derived
    # from the metadata mapping; verify against _write_row's implementation.
    assert any(item[0] == "metadata_tags" for item in attrs)
    vector._delete_row("id-1")
    vector._tablestore_client.delete_row.assert_called_once()
    # metadata search pagination: a non-empty next_token triggers a second search call
    first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next")
    second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"")
    vector._tablestore_client.search.side_effect = [first_page, second_page]
    ids = vector._search_by_metadata("document_id", "doc-1")
    assert ids == ["row-1", "row-2"]
    vector._tablestore_client.search.side_effect = None
    # vector search: the hit scored 0.2 falls below the 0.5 threshold and is dropped
    hit1 = SimpleNamespace(
        score=0.9,
        row=(
            None,
            [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))],
        ),
    )
    hit2 = SimpleNamespace(
        score=0.2,
        row=(
            None,
            [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))],
        ),
    )
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2])
    docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    # exp-decay normalization: 0 maps to 0.0 and large scores stay within [0, 1]
    assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0)
    assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0
    # full text search with and without normalized score filter
    vector._normalize_full_text_bm25_score = True
    hit3 = SimpleNamespace(
        score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))])
    )
    hit4 = SimpleNamespace(
        score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))])
    )
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4])
    docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2)
    assert len(docs) == 1
    assert "score" in docs[0].metadata
    # with normalization disabled, no score key is attached to the metadata
    vector._normalize_full_text_bm25_score = False
    vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3])
    docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0)
    assert len(docs) == 1
    assert "score" not in docs[0].metadata
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch):
    """The factory reuses a stored collection name or generates one from the dataset id."""
    factory = tablestore_module.TableStoreVectorFactory()
    indexed_dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    fresh_dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full connection config the factory reads from dify_config.
    for setting, value in {
        "TABLESTORE_ENDPOINT": "endpoint",
        "TABLESTORE_INSTANCE_NAME": "instance",
        "TABLESTORE_ACCESS_KEY_ID": "ak",
        "TABLESTORE_ACCESS_KEY_SECRET": "sk",
        "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE": True,
    }.items():
        monkeypatch.setattr(tablestore_module.dify_config, setting, value)
    with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls:
        assert factory.init_vector(indexed_dataset, attributes=[], embeddings=MagicMock()) == "vector"
        assert factory.init_vector(fresh_dataset, attributes=[], embeddings=MagicMock()) == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION"
    # A dataset without an index struct gets one written back.
    assert fresh_dataset.index_struct is not None

View File

@ -0,0 +1,309 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_tencent_modules():
    """Build fake tcvdb_text / tcvectordb module trees for import-time patching.

    Returns a mapping of dotted module name -> module object covering every
    symbol that tencent_vector.py imports from the real SDKs, so the tests run
    without the actual Tencent VectorDB dependencies installed.
    """
    tcvdb_text = types.ModuleType("tcvdb_text")
    tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder")
    tcvectordb = types.ModuleType("tcvectordb")
    tcvectordb_model = types.ModuleType("tcvectordb.model")
    tcvectordb_document = types.ModuleType("tcvectordb.model.document")
    tcvectordb_index = types.ModuleType("tcvectordb.model.index")
    tcvectordb_enum = types.ModuleType("tcvectordb.model.enum")

    class _BM25Encoder:
        # Sparse-encoding stub: wraps inputs so tests can see what was encoded.
        def encode_texts(self, text):
            return {"encoded_text": text}

        def encode_queries(self, query):
            return {"encoded_query": query}

        @classmethod
        def default(cls, _lang):
            return cls()

    class VectorDBError(Exception):
        # Mirrors the SDK exception's `message` attribute.
        def __init__(self, message):
            super().__init__(message)
            self.message = message

    class RPCVectorDBClient:
        # All API methods are MagicMocks; ctor kwargs are kept for assertions.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.create_database_if_not_exists = MagicMock()
            self.exists_collection = MagicMock(return_value=False)
            self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[]))
            self.create_collection = MagicMock()
            self.upsert = MagicMock()
            self.query = MagicMock(return_value=[])
            self.delete = MagicMock()
            self.search = MagicMock(return_value=[])
            self.hybrid_search = MagicMock(return_value=[])
            self.drop_collection = MagicMock()

    class _Document:
        # Arbitrary keyword fields become instance attributes, like the SDK model.
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    class _HNSWSearchParams:
        def __init__(self, ef):
            self.ef = ef

    class _AnnSearch:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _KeywordSearch:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _WeightedRerank:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _Filter:
        @staticmethod
        def in_(field, values):
            return ("in", field, values)

        def __init__(self, condition):
            self.condition = condition

    # The real SDK exposes Filter.In; alias the snake_case helper onto it.
    _Filter.In = staticmethod(_Filter.in_)

    class _HNSWParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _FilterIndex:
        def __init__(self, *args):
            self.args = args

    class _VectorIndex:
        def __init__(self, *args):
            self.args = args

    class _SparseIndex:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    # Enum stand-ins: __members__ mirrors the name-based lookups in the module.
    tcvectordb_enum.IndexType = SimpleNamespace(
        __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"},
        PRIMARY_KEY="PRIMARY_KEY",
        FILTER="FILTER",
        SPARSE_INVERTED="SPARSE",
    )
    tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP")
    tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector")
    tcvectordb_document.Document = _Document
    tcvectordb_document.HNSWSearchParams = _HNSWSearchParams
    tcvectordb_document.AnnSearch = _AnnSearch
    tcvectordb_document.Filter = _Filter
    tcvectordb_document.KeywordSearch = _KeywordSearch
    tcvectordb_document.WeightedRerank = _WeightedRerank
    tcvectordb_index.HNSWParams = _HNSWParams
    tcvectordb_index.FilterIndex = _FilterIndex
    tcvectordb_index.VectorIndex = _VectorIndex
    tcvectordb_index.SparseIndex = _SparseIndex
    tcvdb_text_encoder.BM25Encoder = _BM25Encoder
    # Wire the submodules together the way the real packages expose them.
    tcvectordb_model.document = tcvectordb_document
    tcvectordb_model.enum = tcvectordb_enum
    tcvectordb_model.index = tcvectordb_index
    tcvectordb.RPCVectorDBClient = RPCVectorDBClient
    tcvectordb.VectorDBException = VectorDBError
    return {
        "tcvdb_text": tcvdb_text,
        "tcvdb_text.encoder": tcvdb_text_encoder,
        "tcvectordb": tcvectordb,
        "tcvectordb.model": tcvectordb_model,
        "tcvectordb.model.document": tcvectordb_document,
        "tcvectordb.model.index": tcvectordb_index,
        "tcvectordb.model.enum": tcvectordb_enum,
    }
@pytest.fixture
def tencent_module(monkeypatch):
    """Reload the tencent vector module against fake tcvectordb/tcvdb_text SDKs."""
    for module_name, fake in _build_fake_tencent_modules().items():
        monkeypatch.setitem(sys.modules, module_name, fake)
    import core.rag.datasource.vdb.tencent.tencent_vector as module

    # Reload so the module binds to the freshly injected fakes.
    return importlib.reload(module)
def _config(module, **overrides):
    """Build a valid TencentConfig, letting tests override individual fields."""
    base = {
        "url": "http://vdb.local",
        "api_key": "api-key",
        "timeout": 30,
        "username": "user",
        "database": "db",
        "index_type": "HNSW",
        "metric_type": "IP",
        "shard": 1,
        "replicas": 2,
        "max_upsert_batch_size": 2,
        "enable_hybrid_search": False,
    }
    base.update(overrides)
    return module.TencentConfig.model_validate(base)
def test_config_and_init_paths(tencent_module):
    """Config maps into client params; _load_collection toggles hybrid search by index shape."""
    config = _config(tencent_module)
    assert config.to_tencent_params()["url"] == "http://vdb.local"
    tc_vector = tencent_module.TencentVector("collection_1", config)
    assert tc_vector.get_type() == tencent_module.VectorType.TENCENT
    assert tc_vector._client.kwargs["key"] == "api-key"

    # Collection exposing a sparse_vector index enables hybrid search.
    tc_vector._client.exists_collection.return_value = True
    tc_vector._client.describe_collection.return_value = SimpleNamespace(
        indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)]
    )
    tc_vector._client_config.enable_hybrid_search = True
    tc_vector._load_collection()
    assert tc_vector._enable_hybrid_search is True
    assert tc_vector._dimension == 768

    # Without a sparse_vector index, hybrid search is switched off.
    tc_vector._client.describe_collection.return_value = SimpleNamespace(
        indexes=[SimpleNamespace(name="vector", dimension=512)]
    )
    tc_vector._load_collection()
    assert tc_vector._enable_hybrid_search is False
def test_create_collection_branches(tencent_module, monkeypatch):
    """Exercise cache hit, pre-existing collection, bad config, and json-fallback paths."""
    tc_vector = tencent_module.TencentVector("collection_1", _config(tencent_module))
    fake_lock = MagicMock()
    fake_lock.__enter__.return_value = None
    fake_lock.__exit__.return_value = None
    monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=fake_lock))
    monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock())

    # Redis cache hit: creation is skipped entirely.
    monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1))
    tc_vector._create_collection(3)
    tc_vector._client.create_collection.assert_not_called()

    # Collection already exists server-side: also skipped.
    monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None))
    tc_vector._client.exists_collection.return_value = True
    tc_vector._create_collection(3)
    tc_vector._client.create_collection.assert_not_called()

    # Invalid index/metric configuration raises a descriptive ValueError.
    tc_vector._client.exists_collection.return_value = False
    tc_vector._client_config.index_type = "UNKNOWN"
    with pytest.raises(ValueError, match="unsupported index_type"):
        tc_vector._create_collection(3)
    tc_vector._client_config.index_type = "HNSW"
    tc_vector._client_config.metric_type = "UNKNOWN"
    with pytest.raises(ValueError, match="unsupported metric_type"):
        tc_vector._create_collection(3)
    tc_vector._client_config.metric_type = "IP"

    # First attempt fails on json field support; the retry succeeds and the
    # cache flag is written.
    tc_vector._client.create_collection.side_effect = [
        tencent_module.VectorDBException("fieldType:json unsupported"),
        None,
    ]
    tc_vector._enable_hybrid_search = True
    tc_vector._create_collection(3)
    assert tc_vector._client.create_collection.call_count == 2
    tencent_module.redis_client.set.assert_called_once()
    tc_vector._client.create_collection.side_effect = None
def test_create_add_delete_and_search_behaviour(tencent_module):
    """End-to-end behaviour of create/add/delete/search with hybrid search enabled."""
    vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True))
    vector._create_collection = MagicMock()
    docs = [
        Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}),
        Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}),
        Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}),
    ]
    embeddings = [[0.1], [0.2], [0.3]]
    vector.create(docs, embeddings)
    # The collection is sized from the embedding dimension (here 1).
    vector._create_collection.assert_called_once_with(1)
    vector._client.upsert.reset_mock()
    # max_upsert_batch_size=2 (from _config) splits three docs into two batches.
    vector.add_texts(docs, embeddings)
    assert vector._client.upsert.call_count == 2
    first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"]
    # Hybrid search enabled -> upserted documents carry a sparse_vector field.
    assert "sparse_vector" in first_docs[0].__dict__
    vector._client.query.return_value = [{"id": "a"}]
    assert vector.text_exists("a") is True
    vector._client.query.return_value = []
    assert vector.text_exists("a") is False
    vector.delete_by_ids([])
    vector._client.delete.assert_not_called()
    # Three ids with batch size 2 -> two delete calls.
    vector.delete_by_ids(["a", "b", "c"])
    assert vector._client.delete.call_count == 2
    vector.delete_by_metadata_field("document_id", "doc-a")
    assert vector._client.delete.call_count >= 3
    vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]]
    vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"])
    assert len(vec_docs) == 1
    assert vec_docs[0].metadata["score"] == pytest.approx(0.9)
    # Full-text search short-circuits to [] when hybrid search is disabled.
    vector._enable_hybrid_search = False
    assert vector.search_by_full_text("query") == []
    vector._enable_hybrid_search = True
    vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]]
    fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"])
    assert len(fts_docs) == 1
    # _get_search_res handles old string metadata format
    # NOTE(review): input score is 0.2 (below the 0.5 threshold) yet the doc is
    # kept and asserted at 0.8 — verify _get_search_res's score semantics.
    compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.5)
    assert len(compat_docs) == 1
    assert compat_docs[0].metadata["score"] == pytest.approx(0.8)
    vector._has_collection = MagicMock(return_value=True)
    vector.delete()
    vector._client.drop_collection.assert_called_once()
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch):
    """The factory lower-cases the stored or generated collection name."""
    factory = tencent_module.TencentVectorFactory()
    indexed_dataset = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    fresh_dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full connection config the factory reads from dify_config.
    for setting, value in {
        "TENCENT_VECTOR_DB_URL": "http://vdb.local",
        "TENCENT_VECTOR_DB_API_KEY": "api-key",
        "TENCENT_VECTOR_DB_TIMEOUT": 30,
        "TENCENT_VECTOR_DB_USERNAME": "user",
        "TENCENT_VECTOR_DB_DATABASE": "db",
        "TENCENT_VECTOR_DB_SHARD": 1,
        "TENCENT_VECTOR_DB_REPLICAS": 2,
        "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH": True,
    }.items():
        monkeypatch.setattr(tencent_module.dify_config, setting, value)
    with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls:
        assert factory.init_vector(indexed_dataset, attributes=[], embeddings=MagicMock()) == "vector"
        assert factory.init_vector(fresh_dataset, attributes=[], embeddings=MagicMock()) == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    # A dataset without an index struct gets one written back.
    assert fresh_dataset.index_struct is not None

View File

@ -0,0 +1,88 @@
from types import SimpleNamespace
import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
class _DummyVector(BaseVector):
    """Minimal concrete BaseVector used to exercise the base-class helpers."""

    def __init__(self, collection_name: str, existing_ids: set[str] | None = None):
        super().__init__(collection_name)
        # Ids that text_exists() should report as already stored.
        self._existing_ids = existing_ids or set()

    def get_type(self) -> str:
        return "dummy"

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        pass

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        pass

    def text_exists(self, id: str) -> bool:
        return id in self._existing_ids

    def delete_by_ids(self, ids: list[str]):
        pass

    def delete_by_metadata_field(self, key: str, value: str):
        pass

    def search_by_vector(self, query_vector: list[float], **kwargs):
        return []

    def search_by_full_text(self, query: str, **kwargs):
        return []

    def delete(self):
        pass
@pytest.mark.parametrize(
    ("base_method", "args"),
    [
        (BaseVector.get_type, ()),
        (BaseVector.create, ([], [])),
        (BaseVector.add_texts, ([], [])),
        (BaseVector.text_exists, ("doc-1",)),
        (BaseVector.delete_by_ids, ([],)),
        (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")),
        (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")),
        (BaseVector.search_by_vector, ([0.1],)),
        (BaseVector.search_by_full_text, ("query",)),
        (BaseVector.delete, ()),
    ],
)
def test_base_vector_default_methods_raise_not_implemented(base_method, args):
    """Each BaseVector method's default implementation raises NotImplementedError."""
    # Call the base-class method directly so subclass overrides don't mask it.
    dummy = _DummyVector("collection_1")
    with pytest.raises(NotImplementedError):
        base_method(dummy, *args)
def test_filter_duplicate_texts_removes_existing_docs():
    """_filter_duplicate_texts drops only docs whose doc_id already exists."""
    dummy = _DummyVector("collection_1", existing_ids={"dup"})
    candidates = [
        SimpleNamespace(page_content="keep-no-meta", metadata=None),
        Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}),
        Document(page_content="remove-dup", metadata={"doc_id": "dup"}),
        Document(page_content="keep-unique", metadata={"doc_id": "unique"}),
    ]
    kept = dummy._filter_duplicate_texts(candidates)
    assert [doc.page_content for doc in kept] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"]
def test_get_uuids_and_collection_name_property():
    """_get_uuids collects only doc_id metadata values; collection_name echoes the ctor arg."""
    dummy = _DummyVector("collection_1")
    candidates = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        SimpleNamespace(page_content="b", metadata=None),
        Document(page_content="c", metadata={"document_id": "d-1"}),
        Document(page_content="d", metadata={"doc_id": "id-2"}),
    ]
    assert dummy._get_uuids(candidates) == ["id-1", "id-2"]
    assert dummy.collection_name == "collection_1"

View File

@ -0,0 +1,434 @@
import base64
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str):
    """Install a throwaway module exposing an empty class and return that class."""
    stub_module = types.ModuleType(module_path)
    stub_cls = type(class_name, (), {})
    setattr(stub_module, class_name, stub_cls)
    # Registering in sys.modules lets the lazy factory import resolve to the stub.
    monkeypatch.setitem(sys.modules, module_path, stub_module)
    return stub_cls
@pytest.fixture
def vector_factory_module():
    """Yield a freshly reloaded vector_factory module for each test."""
    import importlib

    import core.rag.datasource.vdb.vector_factory as module

    return importlib.reload(module)
def test_gen_index_struct_dict(vector_factory_module):
    """gen_index_struct_dict nests the collection name under vector_store.class_prefix."""
    struct = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict(
        vector_factory_module.VectorType.WEAVIATE,
        "collection_1",
    )
    expected = {
        "type": vector_factory_module.VectorType.WEAVIATE,
        "vector_store": {"class_prefix": "collection_1"},
    }
    assert struct == expected
# Table of (VectorType name, lazily imported module path, factory class name)
# mirroring the dispatch inside Vector.get_vector_factory.
@pytest.mark.parametrize(
    ("vector_type", "module_path", "class_name"),
    [
        ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"),
        ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"),
        (
            "ALIBABACLOUD_MYSQL",
            "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector",
            "AlibabaCloudMySQLVectorFactory",
        ),
        ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"),
        ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"),
        ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"),
        ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"),
        ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"),
        ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"),
        (
            "ELASTICSEARCH",
            "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector",
            "ElasticSearchVectorFactory",
        ),
        (
            "ELASTICSEARCH_JA",
            "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector",
            "ElasticSearchJaVectorFactory",
        ),
        ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"),
        ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"),
        ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"),
        ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"),
        (
            "OPENSEARCH",
            "core.rag.datasource.vdb.opensearch.opensearch_vector",
            "OpenSearchVectorFactory",
        ),
        ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"),
        ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"),
        ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"),
        ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"),
        ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"),
        (
            "TIDB_ON_QDRANT",
            "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector",
            "TidbOnQdrantVectorFactory",
        ),
        ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"),
        ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"),
        ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"),
        ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"),
        ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"),
        (
            "HUAWEI_CLOUD",
            "core.rag.datasource.vdb.huawei.huawei_cloud_vector",
            "HuaweiCloudVectorFactory",
        ),
        ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"),
        ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"),
        ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"),
    ],
)
def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name):
    """Each supported VectorType resolves to its factory class via the stubbed module."""
    expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name)
    result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type))
    assert result_cls is expected_cls
def test_get_vector_factory_unsupported(vector_factory_module):
    """An unrecognized vector type must raise ValueError mentioning 'not supported'."""
    bogus_type = "unknown"
    with pytest.raises(ValueError, match="not supported"):
        vector_factory_module.Vector.get_vector_factory(bogus_type)
def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
    """Vector() uses the default attribute list unless one is passed explicitly."""
    dataset = SimpleNamespace(id="dataset-1")
    # Stub out the expensive initializers so __init__ only wires attributes.
    with (
        patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"),
        patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"),
    ):
        default_vector = vector_factory_module.Vector(dataset)
        custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
    assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
    assert custom_vector._attributes == ["doc_id"]
    assert default_vector._embeddings == "embeddings"
    assert default_vector._vector_processor == "processor"
def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch):
    """_init_vector resolves the factory from the dataset's recorded index-struct type."""
    # Records which vector type and init arguments reach the (fake) factory.
    calls = {"vector_type": None, "init_args": None}
    class _Factory:
        def init_vector(self, dataset, attributes, embeddings):
            calls["init_args"] = (dataset, attributes, embeddings)
            return "vector-processor"
    monkeypatch.setattr(
        vector_factory_module.Vector,
        "get_vector_factory",
        staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory),
    )
    # __new__ bypasses __init__ so instance state can be wired by hand.
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(
        index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1"
    )
    vector._attributes = ["doc_id"]
    vector._embeddings = "embeddings"
    result = vector._init_vector()
    assert result == "vector-processor"
    assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH
    assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings")
def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch):
    """A whitelisted tenant is routed to TIDB_ON_QDRANT even when config says CHROMA."""
    class _Expr:
        # Stand-in for a SQLAlchemy column; equality yields a filter expression.
        def __eq__(self, _other):
            return "expr"
    calls = {"vector_type": None}
    class _Factory:
        def init_vector(self, dataset, attributes, embeddings):
            return "vector-processor"
    monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr()))
    monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    # Simulate the whitelist query returning a matching row for the tenant.
    monkeypatch.setattr(
        vector_factory_module,
        "db",
        SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: object()))),
    )
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA)
    monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True)
    monkeypatch.setattr(
        vector_factory_module.Vector,
        "get_vector_factory",
        staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory),
    )
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1")
    vector._attributes = ["doc_id"]
    vector._embeddings = "embeddings"
    result = vector._init_vector()
    assert result == "vector-processor"
    assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT
def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch):
    """_init_vector fails clearly when no vector store is configured and no index struct exists."""
    for option, value in (("VECTOR_STORE", None), ("VECTOR_STORE_WHITELIST_ENABLE", False)):
        monkeypatch.setattr(vector_factory_module.dify_config, option, value)
    instance = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    instance._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1")
    instance._attributes = []
    instance._embeddings = "embeddings"
    with pytest.raises(ValueError, match="Vector store must be specified"):
        instance._init_vector()
def test_create_batches_texts_and_skips_empty_input(vector_factory_module):
    """create() embeds documents in batches of 1000 and short-circuits on empty input."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()
    # 1001 docs -> one full batch of 1000 plus a trailing batch of 1.
    docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)]
    vector._embeddings.embed_documents.side_effect = [
        [[0.1] for _ in range(1000)],
        [[0.2]],
    ]
    vector.create(texts=docs, trace_id="trace-1")
    assert vector._embeddings.embed_documents.call_count == 2
    assert vector._vector_processor.create.call_count == 2
    assert vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1"
    vector._embeddings.embed_documents.reset_mock()
    vector._vector_processor.create.reset_mock()
    # None input must neither embed nor write anything.
    vector.create(texts=None)
    vector._embeddings.embed_documents.assert_not_called()
    vector._vector_processor.create.assert_not_called()
def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch):
    """create_multimodal embeds only docs whose UploadFile rows exist; f-2 is dropped."""
    class _Field:
        # Stand-in for a SQLAlchemy column supporting in_() and == filters.
        def in_(self, value):
            return value
        def __eq__(self, value):
            return value
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]]
    vector._vector_processor = MagicMock()
    monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
    monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt"))
    # Only file id "f-1" has an UploadFile row in the fake DB session.
    monkeypatch.setattr(
        vector_factory_module,
        "db",
        SimpleNamespace(
            session=SimpleNamespace(
                scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")])
            )
        ),
    )
    monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc"))
    docs = [
        Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}),
        Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}),
    ]
    vector.create_multimodal(file_documents=docs, request_id="r-1")
    # File bytes are base64-encoded before being passed to the embedder.
    file_base64 = base64.b64encode(b"abc").decode()
    vector._embeddings.embed_multimodal_documents.assert_called_once_with(
        [{"content": file_base64, "content_type": "image", "file_id": "f-1"}]
    )
    vector._vector_processor.create.assert_called_once_with(
        texts=[docs[0]],
        embeddings=[[0.1, 0.2]],
        request_id="r-1",
    )
    vector._embeddings.embed_multimodal_documents.reset_mock()
    vector._vector_processor.create.reset_mock()
    # None input is a no-op.
    vector.create_multimodal(file_documents=None)
    vector._embeddings.embed_multimodal_documents.assert_not_called()
    vector._vector_processor.create.assert_not_called()
def test_add_texts_with_optional_duplicate_check(vector_factory_module):
    """add_texts filters duplicates only when duplicate_check=True and forwards kwargs."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()
    vector._filter_duplicate_texts = MagicMock()
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-1"}),
        Document(page_content="b", metadata={"doc_id": "id-2"}),
    ]
    # With duplicate_check=True only the filtered subset is embedded and stored.
    vector._filter_duplicate_texts.return_value = [docs[0]]
    vector._embeddings.embed_documents.return_value = [[0.1]]
    vector.add_texts(docs, duplicate_check=True, flag=True)
    vector._filter_duplicate_texts.assert_called_once_with(docs)
    vector._vector_processor.create.assert_called_once_with(
        texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True
    )
    vector._filter_duplicate_texts.reset_mock()
    vector._vector_processor.create.reset_mock()
    # With duplicate_check=False the filter must not run at all.
    vector._embeddings.embed_documents.return_value = [[0.2], [0.3]]
    vector.add_texts(docs, duplicate_check=False)
    vector._filter_duplicate_texts.assert_not_called()
    vector._vector_processor.create.assert_called_once()
def test_vector_delegation_methods(vector_factory_module):
    """Thin wrapper methods delegate straight to the underlying vector processor."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._embeddings.embed_query.return_value = [0.1, 0.2]
    vector._vector_processor = MagicMock()
    vector._vector_processor.text_exists.return_value = True
    vector._vector_processor.search_by_vector.return_value = ["vector-doc"]
    vector._vector_processor.search_by_full_text.return_value = ["text-doc"]
    assert vector.text_exists("doc-1") is True
    vector.delete_by_ids(["doc-1"])
    vector.delete_by_metadata_field("doc_id", "doc-1")
    assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"]
    assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"]
    vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"])
    vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1")
def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
    """search_by_file returns [] for an unknown upload and searches on the embedded bytes otherwise."""
    class _Field:
        # Stand-in for a SQLAlchemy column; equality yields a filter expression.
        def __eq__(self, value):
            return value
    upload_query = MagicMock()
    upload_query.where.return_value = upload_query
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._embeddings = MagicMock()
    vector._vector_processor = MagicMock()
    monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
    monkeypatch.setattr(
        vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
    )
    # No UploadFile row -> empty result, no embedding performed.
    upload_query.first.return_value = None
    assert vector.search_by_file("file-1") == []
    # Existing row -> bytes are loaded, embedded as an image query, then searched.
    upload_query.first.return_value = SimpleNamespace(key="blob-key")
    monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
    vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
    vector._vector_processor.search_by_vector.return_value = ["hit"]
    result = vector.search_by_file("file-2", top_k=2)
    assert result == ["hit"]
    payload = vector._embeddings.embed_multimodal_query.call_args.args[0]
    assert payload["content_type"] == vector_factory_module.DocType.IMAGE
    assert payload["file_id"] == "file-2"
def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch):
    """delete() clears the vector_indexing_* redis key only when a collection name is set."""
    delete_mock = MagicMock()
    redis_delete = MagicMock()
    monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete)
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1")
    vector.delete()
    delete_mock.assert_called_once()
    redis_delete.assert_called_once_with("vector_indexing_collection_1")
    # An empty collection name must not trigger a redis delete.
    vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="")
    redis_delete.reset_mock()
    vector.delete()
    redis_delete.assert_not_called()
def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch):
    """_get_embeddings wraps the dataset's embedding model instance in a CacheEmbedding."""
    model_manager = MagicMock()
    model_manager.get_model_instance.return_value = "model-instance"
    monkeypatch.setattr(vector_factory_module, "ModelManager", MagicMock(return_value=model_manager))
    monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding"))
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    vector._dataset = SimpleNamespace(
        tenant_id="tenant-1",
        embedding_model_provider="openai",
        embedding_model="text-embedding-3-small",
    )
    result = vector._get_embeddings()
    assert result == "cached-embedding"
    # The model instance must be resolved with the dataset's provider/model settings.
    model_manager.get_model_instance.assert_called_once_with(
        tenant_id="tenant-1",
        provider="openai",
        model_type=vector_factory_module.ModelType.TEXT_EMBEDDING,
        model="text-embedding-3-small",
    )
def test_filter_duplicate_texts_and_getattr(vector_factory_module):
    """_filter_duplicate_texts drops known doc_ids; __getattr__ proxies to the processor."""
    vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
    # text_exists reports only "dup" as already present.
    vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup")
    docs = [
        SimpleNamespace(page_content="no-meta", metadata=None),
        Document(page_content="empty-doc-id", metadata={"doc_id": ""}),
        Document(page_content="duplicate", metadata={"doc_id": "dup"}),
        Document(page_content="unique", metadata={"doc_id": "ok"}),
    ]
    filtered = vector._filter_duplicate_texts(docs)
    # Docs without metadata or with an empty doc_id pass through unchecked.
    assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"]
    class _Processor:
        def ping(self):
            return "pong"
    vector._vector_processor = _Processor()
    # Unknown attributes are looked up on the wrapped processor.
    assert vector.ping() == "pong"
    with pytest.raises(AttributeError):
        _ = vector.unknown_method
    # With no processor the proxy raises an AttributeError naming vector_processor.
    vector._vector_processor = None
    with pytest.raises(AttributeError, match="vector_processor"):
        _ = vector.another_missing

View File

@ -0,0 +1,443 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
@pytest.fixture
def tidb_module():
    """Provide a freshly reloaded tidb_vector module so patched state never leaks between tests."""
    import core.rag.datasource.vdb.tidb_vector.tidb_vector as module
    return importlib.reload(module)
def _config(tidb_module):
    """Build a valid baseline TiDBVectorConfig for the tests in this module."""
    params = {
        "host": "localhost",
        "port": 4000,
        "user": "root",
        "password": "secret",
        "database": "dify",
        "program_name": "dify-app",
    }
    return tidb_module.TiDBVectorConfig(**params)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("host", "", "config TIDB_VECTOR_HOST is required"),
        ("port", 0, "config TIDB_VECTOR_PORT is required"),
        ("user", "", "config TIDB_VECTOR_USER is required"),
        ("database", "", "config TIDB_VECTOR_DATABASE is required"),
        ("program_name", "", "config APPLICATION_NAME is required"),
    ],
)
def test_tidb_config_validation(tidb_module, field, value, message):
    """Each required TiDBVectorConfig field rejects an empty/zero value with its own message."""
    values = _config(tidb_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        tidb_module.TiDBVectorConfig.model_validate(values)
def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
    """Constructor builds the MySQL URL and maps distance_func names to VEC_* SQL functions."""
    monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine"))
    vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2")
    assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR
    assert vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify")
    assert vector._dimension == 1536
    assert vector._get_distance_func() == "VEC_L2_DISTANCE"
    vector._distance_func = "cosine"
    assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
    # Any unrecognized name falls back to cosine distance.
    vector._distance_func = "other"
    assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch):
    """_table defines the expected columns using tidb_vector's VectorType for embeddings."""
    # Fake the optional tidb_vector dependency so _table can import VectorType.
    fake_tidb_vector = types.ModuleType("tidb_vector")
    fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy")
    class _VectorType:
        def __init__(self, dim):
            self.dim = dim
    fake_tidb_sqlalchemy.VectorType = _VectorType
    monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector)
    monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy)
    monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock()))
    # Replace SQLAlchemy constructs with inspectable namespaces.
    monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs))
    monkeypatch.setattr(
        tidb_module,
        "Table",
        lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns),
    )
    vector = tidb_module.TiDBVector("collection_1", _config(tidb_module))
    table = vector._table(3)
    assert table.name == "collection_1"
    column_names = [column.args[0] for column in table.columns]
    assert tidb_module.Field.PRIMARY_KEY in column_names
    assert tidb_module.Field.VECTOR in column_names
    assert tidb_module.Field.TEXT_KEY in column_names
def test_create_calls_collection_and_add_texts(tidb_module):
    """create() sizes the collection from the first embedding and delegates to add_texts."""
    store = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    store._collection_name = "collection_1"
    store._create_collection = MagicMock()
    store.add_texts = MagicMock()
    embeddings = [[0.1, 0.2]]
    documents = [Document(page_content="a", metadata={"doc_id": "id-1"})]
    store.create(documents, embeddings)
    assert store._dimension == 2
    store._create_collection.assert_called_once_with(2)
    store.add_texts.assert_called_once_with(documents, embeddings)
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
    """_create_collection is a no-op when the redis flag marks the collection as created.

    With ``redis_client.get`` returning a truthy value, no Session may be
    opened and the cache flag must not be re-set.
    """
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1))
    monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    # Patch via monkeypatch rather than assigning tidb_module.Session directly,
    # so the module attribute is restored after the test instead of leaking a
    # MagicMock into the (shared) module object.
    session_mock = MagicMock()
    monkeypatch.setattr(tidb_module, "Session", session_mock)
    vector._create_collection(3)
    session_mock.assert_not_called()
    tidb_module.redis_client.set.assert_not_called()
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch):
    """On a cache miss _create_collection runs CREATE TABLE SQL and sets the redis flag."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock))
    # Cache miss: the table must actually be created.
    monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock())
    session = MagicMock()
    class _SessionCtx:
        def __enter__(self):
            return session
        def __exit__(self, exc_type, exc, tb):
            return False
    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector._distance_func = "l2"
    vector._create_collection(3)
    session.begin.assert_called_once()
    # The generated SQL embeds the requested dimension and distance function.
    sql = str(session.execute.call_args.args[0])
    assert "VECTOR<FLOAT>(3)" in sql
    assert "VEC_L2_DISTANCE" in sql
    session.commit.assert_called_once()
    tidb_module.redis_client.set.assert_called_once()
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
    """add_texts inserts rows in batches of 500 and returns the doc ids in order."""
    class _InsertStmt:
        # Minimal insert() replacement recording the table and row payload.
        def __init__(self, table):
            self.table = table
        def values(self, rows):
            return {"table": self.table, "rows": rows}
    monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table))
    conn = MagicMock()
    transaction = MagicMock()
    transaction.__enter__.return_value = None
    transaction.__exit__.return_value = None
    conn.begin.return_value = transaction
    connection_ctx = MagicMock()
    connection_ctx.__enter__.return_value = conn
    connection_ctx.__exit__.return_value = None
    engine = MagicMock()
    engine.connect.return_value = connection_ctx
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._engine = engine
    vector._table = MagicMock(return_value="table")
    # 501 docs -> one full batch of 500 plus a trailing batch of 1.
    docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)]
    embeddings = [[float(i)] for i in range(501)]
    ids = vector.add_texts(docs, embeddings)
    assert ids[0] == "id-0"
    assert len(ids) == 501
    assert conn.execute.call_count == 2
@pytest.fixture
def tidb_vector_with_session(tidb_module, monkeypatch):
    """Yield a bare TiDBVector, its mocked Session object, and the module under test."""
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    session = MagicMock()
    class _SessionCtx:
        # Context manager that always yields the shared mock session.
        def __enter__(self):
            return session
        def __exit__(self, exc_type, exc, tb):
            return False
    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    return vector, session, tidb_module
# search_by_full_text is unsupported on TiDB and always yields an empty result
def test_search_by_full_text_returns_empty(tidb_vector_with_session):
    """Full-text search is not implemented for TiDB, so it returns []."""
    store = tidb_vector_with_session[0]
    assert store.search_by_full_text("query") == []
# text_exists reports True as soon as ids are found for the doc_id
def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session):
    """text_exists is True when get_ids_by_metadata_field yields ids."""
    store = tidb_vector_with_session[0]
    store.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    assert store.text_exists("doc-1") is True
# text_exists reports False when no ids match the doc_id
def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session):
    """text_exists is False when get_ids_by_metadata_field yields nothing."""
    store = tidb_vector_with_session[0]
    store.get_ids_by_metadata_field = MagicMock(return_value=None)
    assert store.text_exists("doc-1") is False
# 4. delete_by_ids delegates to _delete_by_ids when ids found
def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session):
    """delete_by_ids resolves matching row ids and hands them to _delete_by_ids."""
    vector, session, tidb_module = tidb_vector_with_session
    session.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)]
    vector._delete_by_ids = MagicMock()
    # Use real get_ids_by_metadata_field
    vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__(
        vector, tidb_module.TiDBVector
    )
    vector.delete_by_ids(["doc-a", "doc-b"])
    vector._delete_by_ids.assert_called_once_with(["id-a", "id-b"])
# 5. delete_by_ids skips when no ids found
def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session):
    """delete_by_ids performs no deletion when no row ids match the doc ids."""
    vector, session, tidb_module = tidb_vector_with_session
    session.execute.return_value.fetchall.return_value = []
    vector._delete_by_ids = MagicMock()
    # Use real get_ids_by_metadata_field
    vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__(
        vector, tidb_module.TiDBVector
    )
    vector.delete_by_ids(["doc-c"])
    vector._delete_by_ids.assert_not_called()
# 6. get_ids_by_metadata_field returns ids and returns None
def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session):
    """get_ids_by_metadata_field yields matching row ids, or None when nothing matches."""
    vector, session, tidb_module = tidb_vector_with_session
    # Returns ids
    session.execute.return_value.fetchall.return_value = [("id-1",)]
    assert vector.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"]
    # Returns None
    session.execute.return_value.fetchall.return_value = []
    assert vector.get_ids_by_metadata_field("doc_id", "doc-1") is None
# _delete_by_ids must reject a missing id list up front
def test__delete_by_ids_raises_on_none(tidb_module):
    """Passing None for ids raises ValueError before the engine is touched."""
    store = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    with pytest.raises(ValueError, match="No ids provided"):
        store._delete_by_ids(None)
# 2. _delete_by_ids returns True and calls execute
def test__delete_by_ids_returns_true_and_calls_execute(tidb_module):
    """_delete_by_ids issues one DELETE on the connection and reports success."""
    class _IDColumn:
        # Fake id column supporting in_() for the delete filter.
        def in_(self, ids):
            return ids
    class _Delete:
        def where(self, condition):
            return condition
    table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = None
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn_ctx = MagicMock()
    conn_ctx.__enter__.return_value = conn
    conn_ctx.__exit__.return_value = None
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._dimension = 2
    vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx))
    vector._table = MagicMock(return_value=table)
    assert vector._delete_by_ids(["id-1"]) is True
    conn.execute.assert_called_once()
# 3. _delete_by_ids returns False on RuntimeError
def test__delete_by_ids_returns_false_on_runtime_error(tidb_module):
    """A RuntimeError during the DELETE is swallowed and reported as False."""
    class _IDColumn:
        # Fake id column supporting in_() for the delete filter.
        def in_(self, ids):
            return ids
    class _Delete:
        def where(self, condition):
            return condition
    table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete())
    conn = MagicMock()
    tx = MagicMock()
    tx.__enter__.return_value = None
    tx.__exit__.return_value = None
    conn.begin.return_value = tx
    conn_ctx = MagicMock()
    conn_ctx.__enter__.return_value = conn
    conn_ctx.__exit__.return_value = None
    # Force the execute call to fail.
    conn.execute.side_effect = RuntimeError("delete failed")
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._dimension = 2
    vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx))
    vector._table = MagicMock(return_value=table)
    assert vector._delete_by_ids(["id-2"]) is False
# delete_by_metadata_field forwards matching ids to _delete_by_ids
def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module):
    """Ids resolved for the metadata field are passed straight on for deletion."""
    store = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    store.get_ids_by_metadata_field = MagicMock(return_value=["id-3"])
    store._delete_by_ids = MagicMock()
    store.delete_by_metadata_field("doc_id", "doc-3")
    store._delete_by_ids.assert_called_once_with(["id-3"])
# delete_by_metadata_field is a no-op when the field matches nothing
def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module):
    """No deletion happens when the metadata field matches no rows."""
    store = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    store.get_ids_by_metadata_field = MagicMock(return_value=[])
    store._delete_by_ids = MagicMock()
    store.delete_by_metadata_field("doc_id", "doc-4")
    store._delete_by_ids.assert_not_called()
# Test search_by_vector filters and scores
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
    """search_by_vector converts distances to scores and applies the document-id filter."""
    session = MagicMock()
    # Rows are (metadata json, text, distance); score is computed as 1 - distance.
    session.execute.return_value = [
        ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2),
        ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4),
    ]
    session.commit = MagicMock()
    class _SessionCtx:
        def __enter__(self):
            return session
        def __exit__(self, exc_type, exc, tb):
            return False
    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector._distance_func = "cosine"
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=2,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    assert len(docs) == 2
    assert docs[0].metadata["score"] == pytest.approx(0.8)
    assert docs[1].metadata["score"] == pytest.approx(0.6)
    # The SQL must carry the document filter; params carry distance cutoff and top_k.
    sql = str(session.execute.call_args.args[0])
    params = session.execute.call_args.kwargs["params"]
    assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql
    assert params["distance"] == pytest.approx(0.5)
    assert params["top_k"] == 2
    session.commit.assert_not_called()
# Test delete drops table
def test_delete_drops_table(tidb_module, monkeypatch):
    """delete() issues DROP TABLE IF EXISTS for the collection and commits."""
    session = MagicMock()
    session.execute.return_value = None
    session.commit = MagicMock()
    class _SessionCtx:
        def __enter__(self):
            return session
        def __exit__(self, exc_type, exc, tb):
            return False
    monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
    vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
    vector._collection_name = "collection_1"
    vector._engine = MagicMock()
    vector.delete()
    drop_sql = str(session.execute.call_args.args[0])
    assert "DROP TABLE IF EXISTS collection_1" in drop_sql
    session.commit.assert_called_once()
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
    """Factory reuses the dataset's recorded collection name, else generates one; both lowercased."""
    factory = tidb_module.TiDBVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    # Provide the full TiDB connection config the factory reads from dify_config.
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000)
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret")
    monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify")
    monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app")
    with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
    assert result_1 == "vector"
    assert result_2 == "vector"
    assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
    assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
    assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,186 @@
import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.rag.models.document import Document
def _build_fake_upstash_module():
upstash_module = types.ModuleType("upstash_vector")
class Vector:
def __init__(self, id, vector, metadata, data):
self.id = id
self.vector = vector
self.metadata = metadata
self.data = data
class Index:
def __init__(self, url, token):
self.url = url
self.token = token
self.info = MagicMock(return_value=SimpleNamespace(dimension=8))
self.upsert = MagicMock()
self.query = MagicMock(return_value=[])
self.delete = MagicMock()
self.reset = MagicMock()
upstash_module.Vector = Vector
upstash_module.Index = Index
return upstash_module
@pytest.fixture
def upstash_module(monkeypatch):
    """Import the upstash vector module against a fake ``upstash_vector`` dependency."""
    # Remove patched modules if present
    for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]:
        if modname in sys.modules:
            monkeypatch.delitem(sys.modules, modname, raising=False)
    monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module())
    module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector")
    return module
def _config(module):
    """Return a minimal, valid UpstashVectorConfig for the tests in this module."""
    kwargs = {"url": "https://upstash.example", "token": "token-123"}
    return module.UpstashVectorConfig(**kwargs)
@pytest.mark.parametrize(
    ("field", "value", "message"),
    [
        ("url", "", "Upstash URL is required"),
        ("token", "", "Upstash Token is required"),
    ],
)
def test_upstash_config_validation(upstash_module, field, value, message):
    """Each required UpstashVectorConfig field rejects an empty value with its own message."""
    values = _config(upstash_module).model_dump()
    values[field] = value
    with pytest.raises(ValidationError, match=message):
        upstash_module.UpstashVectorConfig.model_validate(values)
def test_init_get_type_and_dimension(upstash_module, monkeypatch):
    """Constructor exposes the UPSTASH type; dimension falls back to 1536 when unknown."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    assert vector.get_type() == upstash_module.VectorType.UPSTASH
    assert vector._table_name == "collection_1"
    assert vector._get_index_dimension() == 8
    # Missing or None dimension info falls back to the 1536 default.
    vector.index.info.return_value = SimpleNamespace(dimension=None)
    assert vector._get_index_dimension() == 1536
    vector.index.info.return_value = None
    assert vector._get_index_dimension() == 1536
    # add_texts assigns a generated uuid as the upserted vector id.
    monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid")
    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.add_texts(docs, [[0.1, 0.2]])
    vector.index.upsert.assert_called_once()
    upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"]
    assert upsert_vectors[0].id == "generated-uuid"
def test_create_text_exists_and_delete_by_ids(upstash_module):
    """create delegates to add_texts; text_exists and delete_by_ids use metadata lookups."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1]])
    vector.add_texts.assert_called_once_with(docs, [[0.1]])
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"])
    assert vector.text_exists("doc-1") is True
    vector.get_ids_by_metadata_field.return_value = []
    assert vector.text_exists("doc-1") is False
    # delete_by_ids collects ids across all doc ids before a single delete call.
    vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]])
    vector._delete_by_ids = MagicMock()
    vector.delete_by_ids(["doc-1", "doc-2", "doc-3"])
    vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"])
def test_delete_helpers_and_search(upstash_module):
    """_delete_by_ids skips empty input; metadata id lookup builds the expected filter."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    # Empty id list must not hit the index.
    vector._delete_by_ids([])
    vector.index.delete.assert_not_called()
    vector._delete_by_ids(["a", "b"])
    vector.index.delete.assert_called_once_with(ids=["a", "b"])
    vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")]
    ids = vector.get_ids_by_metadata_field("doc_id", "doc-1")
    assert ids == ["x-1", "x-2"]
    query_kwargs = vector.index.query.call_args.kwargs
    assert query_kwargs["top_k"] == 1000
    assert query_kwargs["filter"] == "doc_id = 'doc-1'"
    # delete_by_metadata_field only deletes when the lookup finds ids.
    vector._delete_by_ids = MagicMock()
    vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"])
    vector.delete_by_metadata_field("doc_id", "doc-1")
    vector._delete_by_ids.assert_called_once_with(["x-1"])
    vector._delete_by_ids.reset_mock()
    vector.get_ids_by_metadata_field.return_value = []
    vector.delete_by_metadata_field("doc_id", "doc-2")
    vector._delete_by_ids.assert_not_called()
def test_search_by_vector_filter_threshold_and_delete(upstash_module):
    """search_by_vector applies the score threshold, skips malformed hits, and builds the id filter."""
    vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
    vector.index.query.return_value = [
        SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9),
        SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3),  # below threshold
        SimpleNamespace(metadata=None, data="text-3", score=0.99),  # missing metadata is skipped
        SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99),  # missing content is skipped
    ]
    docs = vector.search_by_vector(
        [0.1, 0.2],
        top_k=3,
        score_threshold=0.5,
        document_ids_filter=["d-1", "d-2"],
    )
    # Only the first hit survives the threshold and validity checks.
    assert len(docs) == 1
    assert docs[0].page_content == "text-1"
    assert docs[0].metadata["score"] == pytest.approx(0.9)
    search_kwargs = vector.index.query.call_args.kwargs
    assert search_kwargs["top_k"] == 3
    assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')"
    # Full-text search is unsupported for Upstash; delete() resets the whole index.
    assert vector.search_by_full_text("query") == []
    vector.delete()
    vector.index.reset.assert_called_once()
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch):
    """Factory reuses a stored class_prefix, otherwise generates a name and writes back index_struct."""
    factory = upstash_module.UpstashVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example")
    monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123")
    with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls:
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
        assert result_1 == "vector"
        assert result_2 == "vector"
        # Collection names are lower-cased regardless of the configured prefix casing.
        assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
        assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
        # A dataset without an index struct gets one persisted by init_vector.
        assert dataset_without_index.index_struct is not None

View File

@ -0,0 +1,310 @@
import importlib
import json
import sys
import types
from collections import UserDict
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.rag.models.document import Document
def _build_fake_vikingdb_modules():
    """Construct stand-in ``volcengine``/``volcengine.viking_db`` modules.

    The fakes expose just enough of the VikingDB SDK surface (enums, field
    descriptors, and MagicMock-backed service/collection/index objects) for the
    vector-store module under test to import and exercise without the real SDK.

    Returns:
        dict: mapping of import name -> fake module, keyed so callers can patch
        ``sys.modules`` directly.
    """
    volcengine = types.ModuleType("volcengine")
    volcengine.__path__ = []  # mark as a package so the submodule import resolves
    viking_db = types.ModuleType("volcengine.viking_db")

    class Data(UserDict):
        # Mimics the SDK record type: dict-like access plus a ``fields`` attribute.
        def __init__(self, payload):
            super().__init__(payload)
            self.fields = payload

    class DistanceType:
        L2 = "L2"

    class IndexType:
        HNSW = "HNSW"

    class QuantType:
        Float = "Float"

    class FieldType:
        String = "string"
        Text = "text"
        Vector = "vector"

    class Field:
        # Captures constructor kwargs so tests can inspect schema definitions.
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class VectorIndexParams:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class _Collection:
        # Collection stub whose data operations are all MagicMocks.
        def __init__(self):
            self.upsert_data = MagicMock()
            self.fetch_data = MagicMock(return_value=None)
            self.delete_data = MagicMock()

    class _Index:
        # Index stub whose search entry points default to empty result sets.
        def __init__(self):
            self.search = MagicMock(return_value=[])
            self.search_by_vector = MagicMock(return_value=[])

    class VikingDBService:
        # Service stub: records ctor kwargs and hands out the shared stubs above.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.create_collection = MagicMock()
            self.create_index = MagicMock()
            self.drop_index = MagicMock()
            self.drop_collection = MagicMock()
            self._collection = _Collection()
            self._index = _Index()
            self.get_collection = MagicMock(return_value=self._collection)
            self.get_index = MagicMock(return_value=self._index)

    viking_db.Data = Data
    viking_db.DistanceType = DistanceType
    viking_db.Field = Field
    viking_db.FieldType = FieldType
    viking_db.IndexType = IndexType
    viking_db.QuantType = QuantType
    viking_db.VectorIndexParams = VectorIndexParams
    viking_db.VikingDBService = VikingDBService
    return {"volcengine": volcengine, "volcengine.viking_db": viking_db}
@pytest.fixture
def vikingdb_module(monkeypatch):
    """Reload the vikingdb vector module against freshly patched fake SDK modules."""
    fake_modules = _build_fake_vikingdb_modules()
    for module_name, fake in fake_modules.items():
        monkeypatch.setitem(sys.modules, module_name, fake)
    import core.rag.datasource.vdb.vikingdb.vikingdb_vector as vikingdb_vector
    return importlib.reload(vikingdb_vector)
def _config(module):
    """Build a VikingDBConfig populated with fixed test credentials."""
    settings = {
        "access_key": "ak",
        "secret_key": "sk",
        "host": "host",
        "region": "region",
        "scheme": "https",
        "connection_timeout": 10,
        "socket_timeout": 20,
    }
    return module.VikingDBConfig(**settings)
def test_init_get_type_and_has_checks(vikingdb_module):
    """Basic construction: type tag, derived index name, and existence probes."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB
    # The index name is derived from the collection name with an "_idx" suffix.
    assert vector._index_name == "collection_1_idx"
    assert vector._has_collection() is True
    assert vector._has_index() is True
    # SDK errors during lookup are interpreted as "does not exist".
    vector._client.get_collection.side_effect = RuntimeError("missing")
    assert vector._has_collection() is False
    vector._client.get_collection.side_effect = None
    vector._client.get_index.side_effect = RuntimeError("missing")
    assert vector._has_index() is False
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch):
    """_create_collection is a no-op on a redis cache hit; otherwise it provisions collection + index."""
    lock = MagicMock()
    lock.__enter__.return_value = None
    lock.__exit__.return_value = None
    monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock))
    monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock())
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    # Cache hit: nothing should be created.
    monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1))
    vector._create_collection(3)
    vector._client.create_collection.assert_not_called()
    vector._client.create_index.assert_not_called()
    # Cache miss with neither collection nor index present: create both and set the flag.
    monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None))
    vector._has_collection = MagicMock(return_value=False)
    vector._has_index = MagicMock(return_value=False)
    vector._create_collection(4)
    vector._client.create_collection.assert_called_once()
    vector._client.create_index.assert_called_once()
    vikingdb_module.redis_client.set.assert_called_once()
def test_create_and_add_texts(vikingdb_module):
    """create() provisions with the embedding dimension; add_texts upserts enriched records."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})]
    vector.create(docs, [[0.1, 0.2]])
    vector._create_collection.assert_called_once_with(2)  # dimension taken from the first embedding
    vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
    # A fresh instance exercises the real add_texts path end to end.
    vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module))
    docs = [
        Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}),
        Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}),
    ]
    vector.add_texts(docs, [[0.1], [0.2]])
    vector._client.get_collection.assert_called()
    upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0]
    # Upserted records carry the doc_id as primary key and the instance's group key.
    assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a"
    assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2"
def test_text_exists_and_delete_operations(vikingdb_module):
    """text_exists interprets fetch results; delete paths forward resolved ids."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"})
    assert vector.text_exists("id-1") is True
    # The fake SDK signals a miss via a sentinel message rather than an exception.
    vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(
        fields={"message": "data does not exist"}
    )
    assert vector.text_exists("id-1") is False
    vector._client.get_collection.return_value.fetch_data.return_value = None
    assert vector.text_exists("id-1") is False
    vector.delete_by_ids(["id-1"])
    vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"])
    # delete_by_metadata_field resolves ids first, then defers to delete_by_ids.
    vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"])
    vector.delete_by_ids = MagicMock()
    vector.delete_by_metadata_field("doc_id", "doc-1")
    vector.delete_by_ids.assert_called_once_with(["id-2"])
def test_get_ids_and_search_helpers(vikingdb_module):
    """Metadata-id lookup filters on the stored metadata JSON; search helpers rank and filter."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    vector._client.get_index.return_value.search.return_value = []
    assert vector.get_ids_by_metadata_field("doc_id", "x") == []
    # Only hits whose metadata JSON carries the matching key/value are returned;
    # hits without a metadata field at all are skipped.
    vector._client.get_index.return_value.search.return_value = [
        SimpleNamespace(id="a", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}),
        SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}),
        SimpleNamespace(id="c", fields={}),
    ]
    assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"]
    empty_docs = vector._get_search_res([], score_threshold=0.1)
    assert empty_docs == []
    results = [
        SimpleNamespace(
            id="a",
            score=0.3,
            fields={
                vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a",
                vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}),
            },
        ),
        SimpleNamespace(
            id="b",
            score=0.9,
            fields={
                vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b",
                vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}),
            },
        ),
    ]
    # Results above the threshold are returned sorted by descending score.
    docs = vector._get_search_res(results, score_threshold=0.2)
    assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"]
    # search_by_vector additionally honours the document-ids filter.
    vector._client.get_index.return_value.search_by_vector.return_value = results
    filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"])
    assert len(filtered_docs) == 1
    assert filtered_docs[0].page_content == "doc-b"
    # Full-text search is unsupported for VikingDB.
    assert vector.search_by_full_text("query") == []
def test_delete_drops_index_and_collection_when_present(vikingdb_module):
    """delete() drops index and collection only when both existence probes succeed."""
    vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module))
    for present in (True, False):
        vector._has_index = MagicMock(return_value=present)
        vector._has_collection = MagicMock(return_value=present)
        vector._client.drop_index.reset_mock()
        vector._client.drop_collection.reset_mock()
        vector.delete()
        if present:
            vector._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx")
            vector._client.drop_collection.assert_called_once_with("collection_1")
        else:
            vector._client.drop_index.assert_not_called()
            vector._client.drop_collection.assert_not_called()
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch):
    """Factory reuses a stored class_prefix or generates one, lower-casing either way."""
    factory = vikingdb_module.VikingDBVectorFactory()
    dataset_with_index = SimpleNamespace(
        id="dataset-1",
        index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}},
        index_struct=None,
    )
    dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
    monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
    with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls:
        # Provide the full set of required config values before invoking the factory.
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https")
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10)
        monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20)
        result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock())
        result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock())
        assert result_1 == "vector"
        assert result_2 == "vector"
        # Collection names are lower-cased regardless of the configured prefix casing.
        assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
        assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
        # A dataset without an index struct gets one persisted by init_vector.
        assert dataset_without_index.index_struct is not None
@pytest.mark.parametrize(
    ("field", "message"),
    [
        ("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"),
        ("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"),
        ("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"),
        ("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"),
        ("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"),
    ],
)
def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message):
    """Each required config field, when None, aborts init_vector with a descriptive ValueError."""
    factory = vikingdb_module.VikingDBVectorFactory()
    dataset = SimpleNamespace(
        id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None
    )
    # Start from a fully valid config, then blank out the one field under test.
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region")
    monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https")
    monkeypatch.setattr(vikingdb_module.dify_config, field, None)
    with pytest.raises(ValueError, match=message):
        factory.init_vector(dataset, attributes=[], embeddings=MagicMock())

View File

@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in:
- Full-text search result metadata (search_by_full_text) - Full-text search result metadata (search_by_full_text)
""" """
import datetime
import json
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from core.rag.models.document import Document from core.rag.models.document import Document
@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase):
def tearDown(self): def tearDown(self):
weaviate_vector_module._weaviate_client = None weaviate_vector_module._weaviate_client = None
def test_config_requires_endpoint(self):
    """Constructing a WeaviateConfig without an endpoint must fail validation."""
    empty_endpoint = ""
    with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"):
        WeaviateConfig(endpoint=empty_endpoint)
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def _create_weaviate_vector(self, mock_weaviate_module): def _create_weaviate_vector(self, mock_weaviate_module):
"""Helper to create a WeaviateVector instance with mocked client.""" """Helper to create a WeaviateVector instance with mocked client."""
@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase):
) )
return wv, mock_client return wv, mock_client
def test_shutdown_client_logs_debug_when_close_fails(self):
    """Shutdown swallows close() errors, logs them at debug level, and clears the cached client."""
    mock_client = MagicMock()
    mock_client.close.side_effect = RuntimeError("close failed")
    weaviate_vector_module._weaviate_client = mock_client
    with patch.object(weaviate_vector_module.logger, "debug") as mock_debug:
        weaviate_vector_module._shutdown_weaviate_client()
    # The module-level cache must be cleared even though close() raised.
    assert weaviate_vector_module._weaviate_client is None
    mock_client.close.assert_called_once()
    mock_debug.assert_called_once()
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect):
    """A ready cached client is returned as-is without dialing a new connection."""
    cached_client = MagicMock()
    cached_client.is_ready.return_value = True
    weaviate_vector_module._weaviate_client = cached_client
    wv = WeaviateVector.__new__(WeaviateVector)
    client = wv._init_client(self.config)
    assert client is cached_client
    mock_connect.assert_not_called()
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect):
    """A cached client that becomes ready on the second readiness check is still reused."""
    cached_client = MagicMock()
    # Not ready on the first check, ready on the re-check — no reconnect should happen.
    cached_client.is_ready.side_effect = [False, True]
    weaviate_vector_module._weaviate_client = cached_client
    wv = WeaviateVector.__new__(WeaviateVector)
    client = wv._init_client(self.config)
    assert client is cached_client
    mock_connect.assert_not_called()
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token")
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key):
    """A scheme-less host:port gRPC endpoint is split into host/port with gRPC TLS off."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_connect.return_value = mock_client
    wv = WeaviateVector.__new__(WeaviateVector)
    config = WeaviateConfig(
        endpoint="https://weaviate.example.com",
        grpc_endpoint="grpc.example.com:6000",
        api_key="test-key",
        batch_size=50,
    )
    client = wv._init_client(config)
    assert client is mock_client
    # https endpoint implies port 443 + TLS for HTTP; the bare gRPC endpoint stays insecure.
    assert mock_connect.call_args.kwargs == {
        "http_host": "weaviate.example.com",
        "http_port": 443,
        "http_secure": True,
        "grpc_host": "grpc.example.com",
        "grpc_port": 6000,
        "grpc_secure": False,
        "auth_credentials": "auth-token",
        "skip_init_checks": True,
    }
    mock_api_key.assert_called_once_with("test-key")
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom")
def test_init_client_raises_when_database_not_ready(self, mock_connect):
    """_init_client raises ConnectionError when the freshly connected client reports not ready."""
    not_ready_client = MagicMock()
    not_ready_client.is_ready.return_value = False
    mock_connect.return_value = not_ready_client
    vector = WeaviateVector.__new__(WeaviateVector)
    with pytest.raises(ConnectionError, match="Vector database is not ready"):
        vector._init_client(self.config)
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_init(self, mock_weaviate_module): def test_init(self, mock_weaviate_module):
"""Test WeaviateVector initialization stores attributes including doc_type.""" """Test WeaviateVector initialization stores attributes including doc_type."""
@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase):
assert wv._collection_name == self.collection_name assert wv._collection_name == self.collection_name
assert "doc_type" in wv._attributes assert "doc_type" in wv._attributes
def test_get_type_and_to_index_struct(self):
    """get_type reports WEAVIATE; to_index_struct embeds the collection name as class_prefix."""
    wv = WeaviateVector.__new__(WeaviateVector)
    wv._collection_name = self.collection_name
    assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE
    assert wv.to_index_struct() == {
        "type": weaviate_vector_module.VectorType.WEAVIATE,
        "vector_store": {"class_prefix": self.collection_name},
    }
def test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self):
    """An existing class_prefix is reused with the "_Node" suffix appended."""
    dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1")
    wv = WeaviateVector.__new__(WeaviateVector)
    assert wv.get_collection_name(dataset) == "ExistingCollection_Node"
def test_get_collection_name_generates_name_from_dataset_id(self):
    """Without a stored index struct the name comes from Dataset.gen_collection_name_by_id."""
    dataset = SimpleNamespace(index_struct_dict=None, id="ds-2")
    wv = WeaviateVector.__new__(WeaviateVector)
    with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"):
        assert wv.get_collection_name(dataset) == "Generated_Node"
def test_create_calls_collection_setup_then_add_texts(self):
    """create() first provisions the collection, then delegates the payload to add_texts."""
    vector = WeaviateVector.__new__(WeaviateVector)
    vector._create_collection = MagicMock()
    vector.add_texts = MagicMock()
    documents = [Document(page_content="hello", metadata={})]
    embeddings = [[0.1, 0.2]]
    vector.create(documents, embeddings)
    vector._create_collection.assert_called_once()
    vector.add_texts.assert_called_once_with(documents, embeddings)
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config")
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase):
f"doc_type should be in collection schema properties, got: {property_names}" f"doc_type should be in collection schema properties, got: {property_names}"
) )
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis):
    """A cached redis flag short-circuits _create_collection before any client calls."""
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock()
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = 1  # cache hit: collection already provisioned
    wv = WeaviateVector.__new__(WeaviateVector)
    wv._collection_name = self.collection_name
    wv._client = MagicMock()
    wv._ensure_properties = MagicMock()
    wv._create_collection()
    wv._client.collections.exists.assert_not_called()
    wv._ensure_properties.assert_not_called()
    mock_redis.set.assert_not_called()
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client")
def test_create_collection_logs_and_reraises_errors(self, mock_redis):
    """Errors during collection creation are logged via logger.exception and re-raised."""
    mock_lock = MagicMock()
    mock_lock.__enter__ = MagicMock()
    mock_lock.__exit__ = MagicMock(return_value=False)  # do not swallow the exception
    mock_redis.lock.return_value = mock_lock
    mock_redis.get.return_value = None  # cache miss forces the creation path
    wv = WeaviateVector.__new__(WeaviateVector)
    wv._collection_name = self.collection_name
    wv._client = MagicMock()
    wv._client.collections.exists.side_effect = RuntimeError("create failed")
    with patch.object(weaviate_vector_module.logger, "exception") as mock_exception:
        with pytest.raises(RuntimeError, match="create failed"):
            wv._create_collection()
        mock_exception.assert_called_once()
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module):
"""Test that _ensure_properties adds doc_type when it's missing from existing schema.""" """Test that _ensure_properties adds doc_type when it's missing from existing schema."""
@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase):
added_names = [call.args[0].name for call in add_calls] added_names = [call.args[0].name for call in add_calls]
assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}"
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module):
    """_ensure_properties backfills every missing core property, in a fixed order."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_weaviate_module.connect_to_custom.return_value = mock_client
    mock_client.collections.exists.return_value = True
    mock_col = MagicMock()
    mock_client.collections.use.return_value = mock_col
    mock_cfg = MagicMock()
    # Schema only has "text": all four core properties are expected to be added.
    mock_cfg.properties = [SimpleNamespace(name="text")]
    mock_col.config.get.return_value = mock_cfg
    wv = WeaviateVector(
        collection_name=self.collection_name,
        config=self.config,
        attributes=self.attributes,
    )
    wv._ensure_properties()
    add_calls = mock_col.config.add_property.call_args_list
    added_names = [call.args[0].name for call in add_calls]
    assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"]
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module):
"""Test that _ensure_properties does not add doc_type when it already exists.""" """Test that _ensure_properties does not add doc_type when it already exists."""
@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase):
# No properties should be added # No properties should be added
mock_col.config.add_property.assert_not_called() mock_col.config.add_property.assert_not_called()
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module):
    """Each failed add_property is tolerated and logged as a warning (one per missing property)."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_weaviate_module.connect_to_custom.return_value = mock_client
    mock_client.collections.exists.return_value = True
    mock_col = MagicMock()
    mock_client.collections.use.return_value = mock_col
    mock_cfg = MagicMock()
    mock_cfg.properties = []  # empty schema: all core properties are missing
    mock_col.config.get.return_value = mock_cfg
    mock_col.config.add_property.side_effect = RuntimeError("cannot add")
    wv = WeaviateVector(
        collection_name=self.collection_name,
        config=self.config,
        attributes=self.attributes,
    )
    with patch.object(weaviate_vector_module.logger, "warning") as mock_warning:
        wv._ensure_properties()
    # Four core properties were attempted and each failure was logged.
    assert mock_warning.call_count == 4
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module):
"""Test that search_by_vector returns doc_type in document metadata. """Test that search_by_vector returns doc_type in document metadata.
@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase):
assert len(docs) == 1 assert len(docs) == 1
assert docs[0].metadata.get("doc_type") == "image" assert docs[0].metadata.get("doc_type") == "image"
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module):
    """Missing result metadata yields score 0.0; document_ids_filter adds a query filter."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_weaviate_module.connect_to_custom.return_value = mock_client
    mock_client.collections.exists.return_value = True
    mock_col = MagicMock()
    mock_client.collections.use.return_value = mock_col
    mock_obj = MagicMock()
    mock_obj.properties = {
        "text": "fallback distance result",
        "document_id": "doc-1",
        "doc_id": "segment-1",
    }
    # No per-object metadata forces the default-distance path (score falls back to 0.0).
    mock_obj.metadata = None
    mock_result = MagicMock()
    mock_result.objects = [mock_obj]
    mock_col.query.near_vector.return_value = mock_result
    wv = WeaviateVector(
        collection_name=self.collection_name,
        config=self.config,
        attributes=self.attributes,
    )
    # score_threshold=-1 keeps even zero-score hits in the result.
    docs = wv.search_by_vector(
        query_vector=[0.2] * 3,
        document_ids_filter=["doc-1"],
        top_k=2,
        score_threshold=-1,
    )
    assert len(docs) == 1
    assert docs[0].metadata["score"] == 0.0
    assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module):
    """Vector search against a non-existent collection returns an empty list."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_weaviate_module.connect_to_custom.return_value = mock_client
    mock_client.collections.exists.return_value = False
    wv = WeaviateVector(
        collection_name=self.collection_name,
        config=self.config,
        attributes=self.attributes,
    )
    assert wv.search_by_vector(query_vector=[0.1] * 3) == []
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module):
"""Test that search_by_full_text also returns doc_type in document metadata.""" """Test that search_by_full_text also returns doc_type in document metadata."""
@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase):
assert len(docs) == 1 assert len(docs) == 1
assert docs[0].metadata.get("doc_type") == "image" assert docs[0].metadata.get("doc_type") == "image"
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module):
    """BM25 search forwards a filter for document_ids_filter and preserves object vectors."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_weaviate_module.connect_to_custom.return_value = mock_client
    mock_client.collections.exists.return_value = True
    mock_col = MagicMock()
    mock_client.collections.use.return_value = mock_col
    mock_obj = MagicMock()
    mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"}
    mock_obj.vector = [0.3, 0.4]
    mock_result = MagicMock()
    mock_result.objects = [mock_obj]
    mock_col.query.bm25.return_value = mock_result
    wv = WeaviateVector(
        collection_name=self.collection_name,
        config=self.config,
        attributes=self.attributes,
    )
    docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"])
    assert len(docs) == 1
    assert docs[0].vector == [0.3, 0.4]  # object vector is carried onto the Document
    assert mock_col.query.bm25.call_args.kwargs["filters"] is not None
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module):
    """Full-text search against a non-existent collection returns an empty list."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_weaviate_module.connect_to_custom.return_value = mock_client
    mock_client.collections.exists.return_value = False
    wv = WeaviateVector(
        collection_name=self.collection_name,
        config=self.config,
        attributes=self.attributes,
    )
    assert wv.search_by_full_text(query="missing") == []
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module):
"""Test that add_texts includes doc_type from document metadata in stored properties.""" """Test that add_texts includes doc_type from document metadata in stored properties."""
@ -310,6 +611,135 @@ class TestWeaviateVector(unittest.TestCase):
stored_props = call_kwargs.kwargs.get("properties") stored_props = call_kwargs.kwargs.get("properties")
assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}"
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module):
    """Invalid doc uuids fall back to uuid4; datetime metadata is ISO-formatted; empty embeddings store no vector."""
    mock_client = MagicMock()
    mock_client.is_ready.return_value = True
    mock_weaviate_module.connect_to_custom.return_value = mock_client
    mock_col = MagicMock()
    mock_client.collections.use.return_value = mock_col
    mock_batch = MagicMock()
    mock_batch.__enter__ = MagicMock(return_value=mock_batch)
    mock_batch.__exit__ = MagicMock(return_value=False)
    mock_col.batch.dynamic.return_value = mock_batch
    created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC)
    doc = Document(page_content="text", metadata={"created_at": created_at})
    wv = WeaviateVector(
        collection_name=self.collection_name,
        config=self.config,
        attributes=self.attributes,
    )
    with (
        patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]),
        patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"),
    ):
        ids = wv.add_texts(documents=[doc], embeddings=[[]])
    assert ids == ["fallback-uuid"]
    call_kwargs = mock_batch.add_object.call_args
    assert call_kwargs.kwargs["uuid"] == "fallback-uuid"
    assert call_kwargs.kwargs["vector"] is None  # empty embedding list maps to no vector
    assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat()
def test_is_uuid_handles_invalid_values(self):
    """_is_uuid accepts canonical UUID strings and rejects arbitrary text."""
    vector = WeaviateVector.__new__(WeaviateVector)
    expectations = {
        "123e4567-e89b-12d3-a456-426614174000": True,
        "not-a-uuid": False,
    }
    for candidate, expected in expectations.items():
        assert vector._is_uuid(candidate) is expected
def test_delete_by_metadata_field_returns_when_collection_is_missing(self):
    """No delete is attempted when the backing collection does not exist."""
    wv = WeaviateVector.__new__(WeaviateVector)
    wv._collection_name = self.collection_name
    wv._client = MagicMock()
    wv._client.collections.exists.return_value = False
    wv.delete_by_metadata_field("doc_id", "segment-1")
    wv._client.collections.use.assert_not_called()
def test_delete_by_metadata_field_deletes_matching_objects(self):
    """Matching objects are removed via a single delete_many call on the collection."""
    wv = WeaviateVector.__new__(WeaviateVector)
    wv._collection_name = self.collection_name
    wv._client = MagicMock()
    wv._client.collections.exists.return_value = True
    mock_col = MagicMock()
    wv._client.collections.use.return_value = mock_col
    wv.delete_by_metadata_field("doc_id", "segment-1")
    mock_col.data.delete_many.assert_called_once()
def test_delete_removes_collection_when_present(self):
wv = WeaviateVector.__new__(WeaviateVector)
wv._collection_name = self.collection_name
wv._client = MagicMock()
wv._client.collections.exists.return_value = True
wv.delete()
wv._client.collections.delete.assert_called_once_with(self.collection_name)
def test_text_exists_handles_missing_and_present_documents(self):
wv = WeaviateVector.__new__(WeaviateVector)
wv._collection_name = self.collection_name
wv._client = MagicMock()
wv._client.collections.exists.side_effect = [False, True]
mock_col = MagicMock()
wv._client.collections.use.return_value = mock_col
mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()])
assert wv.text_exists("segment-1") is False
assert wv.text_exists("segment-1") is True
def test_delete_by_ids_handles_missing_collections_and_404s(self):
class FakeUnexpectedStatusCodeError(Exception):
def __init__(self, status_code):
super().__init__(f"status={status_code}")
self.status_code = status_code
wv = WeaviateVector.__new__(WeaviateVector)
wv._collection_name = self.collection_name
wv._client = MagicMock()
wv._client.collections.exists.side_effect = [False, True]
mock_col = MagicMock()
wv._client.collections.use.return_value = mock_col
mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None]
with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
wv.delete_by_ids(["ignored"])
wv.delete_by_ids(["missing-id", "ok-id"])
assert mock_col.data.delete_by_id.call_count == 2
def test_delete_by_ids_reraises_non_404_errors(self):
class FakeUnexpectedStatusCodeError(Exception):
def __init__(self, status_code):
super().__init__(f"status={status_code}")
self.status_code = status_code
wv = WeaviateVector.__new__(WeaviateVector)
wv._collection_name = self.collection_name
wv._client = MagicMock()
wv._client.collections.exists.return_value = True
mock_col = MagicMock()
wv._client.collections.use.return_value = mock_col
mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500)
with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"):
wv.delete_by_ids(["bad-id"])
def test_json_serializable_converts_datetime(self):
wv = WeaviateVector.__new__(WeaviateVector)
created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC)
assert wv._json_serializable(created_at) == created_at.isoformat()
assert wv._json_serializable("plain") == "plain"
class TestVectorDefaultAttributes(unittest.TestCase): class TestVectorDefaultAttributes(unittest.TestCase):
"""Tests for Vector class default attributes list.""" """Tests for Vector class default attributes list."""
@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase):
assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}" assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}"
class TestWeaviateVectorFactory(unittest.TestCase):
    """Tests for WeaviateVectorFactory.init_vector collection-name resolution."""

    def test_init_vector_uses_existing_dataset_index_struct(self):
        """An existing index struct supplies the collection name and is not rewritten."""
        dataset = SimpleNamespace(
            id="dataset-1",
            index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}},
            index_struct=None,
        )
        wanted_attributes = ["doc_id"]
        with (
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88),
            patch(
                "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector"
            ) as vector_cls,
        ):
            factory = weaviate_vector_module.WeaviateVectorFactory()
            created = factory.init_vector(dataset, wanted_attributes, MagicMock())
        assert created == "vector"
        passed = vector_cls.call_args.kwargs
        assert passed["collection_name"] == "ExistingCollection_Node"
        assert passed["attributes"] == wanted_attributes
        config = passed["config"]
        assert config.endpoint == "http://localhost:8080"
        assert config.grpc_endpoint == "localhost:50051"
        assert config.api_key == "api-key"
        assert config.batch_size == 88
        # Dataset already carried an index struct dict, so nothing was written back.
        assert dataset.index_struct is None

    def test_init_vector_generates_collection_and_updates_index_struct(self):
        """Without an index struct, a collection name is generated and persisted on the dataset."""
        dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None)
        wanted_attributes = ["doc_id", "doc_type"]
        with (
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None),
            patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100),
            patch.object(
                weaviate_vector_module.Dataset,
                "gen_collection_name_by_id",
                return_value="GeneratedCollection_Node",
            ),
            patch(
                "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector"
            ) as vector_cls,
        ):
            factory = weaviate_vector_module.WeaviateVectorFactory()
            created = factory.init_vector(dataset, wanted_attributes, MagicMock())
        assert created == "vector"
        assert vector_cls.call_args.kwargs["collection_name"] == "GeneratedCollection_Node"
        assert json.loads(dataset.index_struct) == {
            "type": weaviate_vector_module.VectorType.WEAVIATE,
            "vector_store": {"class_prefix": "GeneratedCollection_Node"},
        }
if __name__ == "__main__":
    unittest.main()

View File

@ -1164,7 +1164,7 @@ class TestConversationStatusCount:
conversation.id = str(uuid4()) conversation.id = str(uuid4())
# Mock the database query to return no messages # Mock the database query to return no messages
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = [] mock_scalars.return_value.all.return_value = []
# Act # Act
@ -1189,7 +1189,7 @@ class TestConversationStatusCount:
conversation.id = conversation_id conversation.id = conversation_id
# Mock the database query to return no messages with workflow_run_id # Mock the database query to return no messages with workflow_run_id
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = [] mock_scalars.return_value.all.return_value = []
# Act # Act
@ -1274,7 +1274,7 @@ class TestConversationStatusCount:
return mock_result return mock_result
# Act & Assert # Act & Assert
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count result = conversation.status_count
# Verify only 2 database queries were made (not N+1) # Verify only 2 database queries were made (not N+1)
@ -1337,7 +1337,7 @@ class TestConversationStatusCount:
return mock_result return mock_result
# Act # Act
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count result = conversation.status_count
# Assert - query should include app_id filter # Assert - query should include app_id filter
@ -1382,7 +1382,7 @@ class TestConversationStatusCount:
), ),
] ]
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: with patch("models.model.db.session.scalars") as mock_scalars:
# Mock the messages query # Mock the messages query
def mock_scalars_side_effect(query): def mock_scalars_side_effect(query):
mock_result = MagicMock() mock_result = MagicMock()
@ -1438,7 +1438,7 @@ class TestConversationStatusCount:
), ),
] ]
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: with patch("models.model.db.session.scalars") as mock_scalars:
def mock_scalars_side_effect(query): def mock_scalars_side_effect(query):
mock_result = MagicMock() mock_result = MagicMock()

View File

@ -13,6 +13,10 @@ import pytest
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
def _oauth_proxy_setex_calls(redis_client) -> list:
return [call for call in redis_client.setex.call_args_list if call.args[0].startswith("oauth_proxy_context:")]
class TestCreateProxyContext: class TestCreateProxyContext:
def test_stores_context_in_redis_with_ttl(self): def test_stores_context_in_redis_with_ttl(self):
context_id = OAuthProxyService.create_proxy_context( context_id = OAuthProxyService.create_proxy_context(
@ -22,8 +26,9 @@ class TestCreateProxyContext:
assert context_id # non-empty UUID string assert context_id # non-empty UUID string
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
redis_client.setex.assert_called_once() oauth_calls = _oauth_proxy_setex_calls(redis_client)
call_args = redis_client.setex.call_args assert len(oauth_calls) == 1
call_args = oauth_calls[0]
key = call_args[0][0] key = call_args[0][0]
ttl = call_args[0][1] ttl = call_args[0][1]
stored_data = json.loads(call_args[0][2]) stored_data = json.loads(call_args[0][2])

View File

@ -211,6 +211,7 @@ def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch
def test_import_app_pending_stores_import_info_in_redis(): def test_import_app_pending_stores_import_info_in_redis():
service = AppDslService(MagicMock()) service = AppDslService(MagicMock())
app_dsl_service.redis_client.setex.reset_mock()
result = service.import_app( result = service.import_app(
account=_account_mock(), account=_account_mock(),
import_mode=ImportMode.YAML_CONTENT, import_mode=ImportMode.YAML_CONTENT,
@ -375,10 +376,13 @@ def test_confirm_import_success_deletes_redis_key(monkeypatch):
created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1")
monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app)
app_dsl_service.redis_client.delete.reset_mock()
result = service.confirm_import(import_id="import-1", account=_account_mock()) result = service.confirm_import(import_id="import-1", account=_account_mock())
assert result.status == ImportStatus.COMPLETED assert result.status == ImportStatus.COMPLETED
assert result.app_id == "confirmed-app" assert result.app_id == "confirmed-app"
app_dsl_service.redis_client.delete.assert_called_once() app_dsl_service.redis_client.delete.assert_called_once_with(
f"{app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX}import-1"
)
def test_confirm_import_exception_returns_failed(monkeypatch): def test_confirm_import_exception_returns_failed(monkeypatch):

View File

@ -405,7 +405,7 @@ class TestAudioServiceTTS:
voice="en-US-Neural", voice="en-US-Neural",
) )
@patch("services.audio_service.db.session", autospec=True) @patch("services.audio_service.db.session")
@patch("services.audio_service.ModelManager", autospec=True) @patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
"""Test successful TTS with message ID.""" """Test successful TTS with message ID."""
@ -549,7 +549,7 @@ class TestAudioServiceTTS:
with pytest.raises(ValueError, match="Text is required"): with pytest.raises(ValueError, match="Text is required"):
AudioService.transcript_tts(app_model=app, text=None) AudioService.transcript_tts(app_model=app, text=None)
@patch("services.audio_service.db.session", autospec=True) @patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory): def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
"""Test that TTS returns None for invalid message ID format.""" """Test that TTS returns None for invalid message ID format."""
# Arrange # Arrange
@ -564,7 +564,7 @@ class TestAudioServiceTTS:
# Assert # Assert
assert result is None assert result is None
@patch("services.audio_service.db.session", autospec=True) @patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory): def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
"""Test that TTS returns None when message doesn't exist.""" """Test that TTS returns None when message doesn't exist."""
# Arrange # Arrange
@ -585,7 +585,7 @@ class TestAudioServiceTTS:
# Assert # Assert
assert result is None assert result is None
@patch("services.audio_service.db.session", autospec=True) @patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory): def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
"""Test that TTS returns None when message answer is empty.""" """Test that TTS returns None when message answer is empty."""
# Arrange # Arrange

View File

@ -313,7 +313,8 @@ class TestEmailDeliveryTestHandler:
recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")],
) )
subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") with patch.object(dify_config, "APP_WEB_URL", "http://example.com"):
subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com")
assert subs["node_title"] == "title" assert subs["node_title"] == "title"
assert subs["form_content"] == "content" assert subs["form_content"] == "content"

View File

@ -316,7 +316,7 @@ class TestTagServiceRetrieval:
- get_tags_by_target_id: Get all tags bound to a specific target - get_tags_by_target_id: Get all tags bound to a specific target
""" """
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_tags_with_binding_counts(self, mock_db_session, factory): def test_get_tags_with_binding_counts(self, mock_db_session, factory):
""" """
Test retrieving tags with their binding counts. Test retrieving tags with their binding counts.
@ -373,7 +373,7 @@ class TestTagServiceRetrieval:
# Verify database query was called # Verify database query was called
mock_db_session.query.assert_called_once() mock_db_session.query.assert_called_once()
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_tags_with_keyword_filter(self, mock_db_session, factory): def test_get_tags_with_keyword_filter(self, mock_db_session, factory):
""" """
Test retrieving tags filtered by keyword (case-insensitive). Test retrieving tags filtered by keyword (case-insensitive).
@ -427,7 +427,7 @@ class TestTagServiceRetrieval:
# 2. Additional WHERE clause for keyword filtering # 2. Additional WHERE clause for keyword filtering
assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause"
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): def test_get_target_ids_by_tag_ids(self, mock_db_session, factory):
""" """
Test retrieving target IDs by tag IDs. Test retrieving target IDs by tag IDs.
@ -483,7 +483,7 @@ class TestTagServiceRetrieval:
# Verify both queries were executed # Verify both queries were executed
assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query"
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory):
""" """
Test that empty tag_ids returns empty list. Test that empty tag_ids returns empty list.
@ -511,7 +511,7 @@ class TestTagServiceRetrieval:
assert results == [], "Should return empty list for empty input" assert results == [], "Should return empty list for empty input"
mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" mock_db_session.scalars.assert_not_called(), "Should not query database for empty input"
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_tag_by_tag_name(self, mock_db_session, factory): def test_get_tag_by_tag_name(self, mock_db_session, factory):
""" """
Test retrieving tags by name. Test retrieving tags by name.
@ -553,7 +553,7 @@ class TestTagServiceRetrieval:
assert len(results) == 1, "Should find exactly one tag" assert len(results) == 1, "Should find exactly one tag"
assert results[0].name == tag_name, "Tag name should match" assert results[0].name == tag_name, "Tag name should match"
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory):
""" """
Test that missing tag_type or tag_name returns empty list. Test that missing tag_type or tag_name returns empty list.
@ -581,7 +581,7 @@ class TestTagServiceRetrieval:
# Verify no database queries were executed # Verify no database queries were executed
mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input"
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_tags_by_target_id(self, mock_db_session, factory): def test_get_tags_by_target_id(self, mock_db_session, factory):
""" """
Test retrieving tags associated with a specific target. Test retrieving tags associated with a specific target.
@ -654,7 +654,7 @@ class TestTagServiceCRUD:
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
@patch("services.tag_service.uuid.uuid4", autospec=True) @patch("services.tag_service.uuid.uuid4", autospec=True)
def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
""" """
@ -743,7 +743,7 @@ class TestTagServiceCRUD:
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
""" """
Test updating a tag name. Test updating a tag name.
@ -795,7 +795,7 @@ class TestTagServiceCRUD:
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_update_tags_raises_error_for_duplicate_name( def test_update_tags_raises_error_for_duplicate_name(
self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory
): ):
@ -827,7 +827,7 @@ class TestTagServiceCRUD:
with pytest.raises(ValueError, match="Tag name already exists"): with pytest.raises(ValueError, match="Tag name already exists"):
TagService.update_tags(args, tag_id="tag-123") TagService.update_tags(args, tag_id="tag-123")
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory):
""" """
Test that updating a non-existent tag raises NotFound. Test that updating a non-existent tag raises NotFound.
@ -859,7 +859,7 @@ class TestTagServiceCRUD:
with pytest.raises(NotFound, match="Tag not found"): with pytest.raises(NotFound, match="Tag not found"):
TagService.update_tags(args, tag_id="nonexistent") TagService.update_tags(args, tag_id="nonexistent")
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_get_tag_binding_count(self, mock_db_session, factory): def test_get_tag_binding_count(self, mock_db_session, factory):
""" """
Test getting the count of bindings for a tag. Test getting the count of bindings for a tag.
@ -895,7 +895,7 @@ class TestTagServiceCRUD:
# Verify count matches expectation # Verify count matches expectation
assert result == expected_count, "Binding count should match" assert result == expected_count, "Binding count should match"
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_delete_tag(self, mock_db_session, factory): def test_delete_tag(self, mock_db_session, factory):
""" """
Test deleting a tag and its bindings. Test deleting a tag and its bindings.
@ -951,7 +951,7 @@ class TestTagServiceCRUD:
# Verify transaction was committed # Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction" mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_delete_tag_raises_not_found(self, mock_db_session, factory): def test_delete_tag_raises_not_found(self, mock_db_session, factory):
""" """
Test that deleting a non-existent tag raises NotFound. Test that deleting a non-existent tag raises NotFound.
@ -999,7 +999,7 @@ class TestTagServiceBindings:
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory):
""" """
Test creating tag bindings. Test creating tag bindings.
@ -1050,7 +1050,7 @@ class TestTagServiceBindings:
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory):
""" """
Test that saving duplicate bindings is idempotent. Test that saving duplicate bindings is idempotent.
@ -1090,7 +1090,7 @@ class TestTagServiceBindings:
mock_db_session.add.assert_not_called(), "Should not create duplicate binding" mock_db_session.add.assert_not_called(), "Should not create duplicate binding"
@patch("services.tag_service.TagService.check_target_exists", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory):
""" """
Test deleting a tag binding. Test deleting a tag binding.
@ -1138,7 +1138,7 @@ class TestTagServiceBindings:
mock_db_session.commit.assert_called_once(), "Should commit transaction" mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.TagService.check_target_exists", autospec=True) @patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory):
""" """
Test that deleting a non-existent binding is a no-op. Test that deleting a non-existent binding is a no-op.
@ -1175,7 +1175,7 @@ class TestTagServiceBindings:
mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete"
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory):
""" """
Test validating that a dataset target exists. Test validating that a dataset target exists.
@ -1216,7 +1216,7 @@ class TestTagServiceBindings:
mock_db_session.query.assert_called_once(), "Should query database for dataset" mock_db_session.query.assert_called_once(), "Should query database for dataset"
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory):
""" """
Test validating that an app target exists. Test validating that an app target exists.
@ -1257,7 +1257,7 @@ class TestTagServiceBindings:
mock_db_session.query.assert_called_once(), "Should query database for app" mock_db_session.query.assert_called_once(), "Should query database for app"
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_check_target_exists_raises_not_found_for_missing_dataset( def test_check_target_exists_raises_not_found_for_missing_dataset(
self, mock_db_session, mock_current_user, factory self, mock_db_session, mock_current_user, factory
): ):
@ -1289,7 +1289,7 @@ class TestTagServiceBindings:
TagService.check_target_exists("knowledge", "nonexistent") TagService.check_target_exists("knowledge", "nonexistent")
@patch("services.tag_service.current_user", autospec=True) @patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True) @patch("services.tag_service.db.session")
def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory):
""" """
Test that missing app raises NotFound. Test that missing app raises NotFound.

View File

@ -59,6 +59,11 @@ def mock_redis():
# Redis is already mocked globally in conftest.py # Redis is already mocked globally in conftest.py
# Reset it for each test # Reset it for each test
redis_client.reset_mock() redis_client.reset_mock()
redis_client.get.reset_mock()
redis_client.setex.reset_mock()
redis_client.delete.reset_mock()
redis_client.lpush.reset_mock()
redis_client.rpop.reset_mock()
redis_client.get.return_value = None redis_client.get.return_value = None
redis_client.setex.return_value = True redis_client.setex.return_value = True
redis_client.delete.return_value = True redis_client.delete.return_value = True